sglang 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/runtime_endpoint.py +14 -4
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -20
- sglang/bench_serving.py +758 -0
- sglang/check_env.py +171 -0
- sglang/global_config.py +3 -1
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/chat_template.py +2 -2
- sglang/lang/ir.py +3 -3
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +31 -5
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +44 -18
- sglang/srt/managers/controller/infer_batch.py +76 -72
- sglang/srt/managers/controller/manager_multi.py +109 -98
- sglang/srt/managers/controller/manager_single.py +105 -50
- sglang/srt/managers/controller/model_runner.py +42 -18
- sglang/srt/managers/controller/radix_cache.py +4 -3
- sglang/srt/managers/controller/schedule_heuristic.py +4 -0
- sglang/srt/managers/controller/tp_worker.py +143 -156
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +46 -58
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +65 -16
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +2 -8
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +130 -108
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +114 -90
- sglang/srt/server_args.py +27 -17
- sglang/srt/utils.py +17 -118
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -159
- sglang-0.1.22.dist-info/RECORD +103 -0
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
- sglang-0.1.20.dist-info/RECORD +0 -82
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
- {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0

sglang/srt/managers/controller/manager_multi.py
@@ -3,19 +3,17 @@ A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
 """
 
-import
+import dataclasses
 import logging
-
+import multiprocessing
+import os
 from enum import Enum, auto
-from typing import Dict
 
+import numpy as np
 import zmq
-import zmq.asyncio
 
-from sglang.
-
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -23,12 +21,15 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger("srt.controller")
 
 
 class LoadBalanceMethod(Enum):
+    """Load balance method."""
+
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
 
@@ -41,151 +42,161 @@ class LoadBalanceMethod(Enum):
         raise ValueError(f"Invalid load balance method: {method}") from exc
 
 
-
+@dataclasses.dataclass
+class WorkerHandle:
+    """Store the handle of a data parallel worker."""
+
+    proc: multiprocessing.Process
+    queue: multiprocessing.Queue
+
+
+class ControllerMulti:
+    """A controller that manages multiple data parallel workers."""
+
     def __init__(
         self,
-        load_balance_method: str,
         server_args: ServerArgs,
         port_args: PortArgs,
        model_overide_args,
     ):
-
+        # Parse args
         self.server_args = server_args
         self.port_args = port_args
+        self.model_overide_args = model_overide_args
+        self.load_balance_method = LoadBalanceMethod.from_str(
+            server_args.load_balance_method
+        )
 
-
-
+        # Init communication
+        context = zmq.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
 
-
+        # Dispatch method
+        self.round_robin_counter = 0
+        dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
         }
-        self.dispatching =
-
-        # Init communication
-        context = zmq.asyncio.Context()
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-        # Init status
-        self.recv_reqs = []
+        self.dispatching = dispatch_lookup[self.load_balance_method]
 
         # Start data parallel workers
-        self.workers
-        tp_size = server_args.tp_size
-
-        def start_dp_worker(i):
-            try:
-                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
-                worker_thread = start_data_parallel_worker(
-                    server_args, port_args, model_overide_args, gpu_ids, i
-                )
-                self.workers[i] = worker_thread
-            except Exception:
-                logger.error(
-                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
-                )
-
+        self.workers = []
         for i in range(server_args.dp_size):
-            start_dp_worker(i)
+            self.start_dp_worker(i)
 
-
-
-
+    def start_dp_worker(self, dp_worker_id: int):
+        tp_size = self.server_args.tp_size
+
+        pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
+            duplex=False
+        )
 
-
-
+        gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
+        queue = multiprocessing.Queue()
+        proc = multiprocessing.Process(
+            target=start_controller_process_single,
+            args=(
+                self.server_args,
+                self.port_args,
+                pipe_controller_writer,
+                self.model_overide_args,
+                True,
+                gpu_ids,
+                dp_worker_id,
+                queue,
+            ),
+        )
+        proc.start()
 
-
-
+        controller_init_state = pipe_controller_reader.recv()
+        if controller_init_state != "init ok":
+            raise RuntimeError(
+                f"Initialization failed. controller_init_state: {controller_init_state}"
+            )
+        self.workers.append(
+            WorkerHandle(
+                proc=proc,
+                queue=queue,
+            )
+        )
 
-
-        available_workers = list(self.workers.keys())
+    def round_robin_scheduler(self, input_requests):
         for r in input_requests:
-            self.
+            self.workers[self.round_robin_counter].queue.put(r)
             self.round_robin_counter = (self.round_robin_counter + 1) % len(
-
+                self.workers
             )
-        return
 
-
+    def shortest_queue_scheduler(self, input_requests):
         for r in input_requests:
-
-
-            )
-            self.put_req_to_worker(worker, r)
-        return
-
-    async def remove_dead_workers(self):
-        for i in list(self.workers.keys()):
-            worker_thread = self.workers[i]
-            if not worker_thread.liveness:
-                worker_thread.join()
-                # move unsuccessful requests back to the queue
-                while not worker_thread.request_queue.empty():
-                    self.recv_reqs.append(worker_thread.request_queue.get())
-                del self.workers[i]
-                logger.info(f"Stale worker {i} removed")
-
-    async def loop_for_forward(self):
-        while True:
-            await self.remove_dead_workers()
+            queue_sizes = [worker.queue.qsize() for worker in self.workers]
+            wid = np.argmin(queue_sizes)
+            self.workers[wid].queue.put(r)
 
-
-
-
-
-            await self.dispatching(next_step_input)
-            # else:
-            #     logger.error("There is no live worker.")
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+            self.dispatching(recv_reqs)
 
-
+    def recv_requests(self):
+        recv_reqs = []
 
-    async def loop_for_recv_requests(self):
         while True:
-
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
             if isinstance(recv_req, FlushCacheReq):
                 # TODO(lsyin): apply more specific flushCacheReq
-                for
-
-            elif isinstance(recv_req, TokenizedGenerateReqInput):
-                self.recv_reqs.append(recv_req)
+                for worker in self.workers:
+                    worker.queue.put(recv_req)
             elif isinstance(recv_req, AbortReq):
                 in_queue = False
-                for i, req in enumerate(
+                for i, req in enumerate(recv_reqs):
                     if req.rid == recv_req.rid:
-
+                        recv_reqs[i] = recv_req
                         in_queue = True
                         break
                 if not in_queue:
                     # Send abort req to all TP groups
-                    for worker in
-
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                recv_reqs.append(recv_req)
             else:
                 logger.error(f"Invalid object: {recv_req}")
 
+        return recv_reqs
+
 
 def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args
+    model_overide_args: dict,
 ):
+    """Start a controller process."""
+
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )
 
     try:
-        controller =
-            server_args.load_balance_method, server_args, port_args, model_overide_args
-        )
+        controller = ControllerMulti(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
 
     pipe_writer.send("init ok")
-
-
-
-
+
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
+    finally:
+        for w in controller.workers:
+            os.kill(w.proc.pid, 9)
+        kill_parent_process()
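
In the new ControllerMulti, each data parallel rank runs as a separate process fed through a multiprocessing.Queue, and incoming requests are assigned by either round-robin or shortest-queue dispatch. The sketch below shows those two dispatch strategies in isolation with plain multiprocessing queues; the class names and the demo at the bottom are invented for illustration and are not part of the sglang API.

# A minimal sketch of the two dispatch strategies, using plain
# multiprocessing queues. Class names and the demo harness are
# illustrative only; they are not part of sglang.
import multiprocessing

import numpy as np


class RoundRobinDispatcher:
    """Send each request to the next worker queue in a fixed rotation."""

    def __init__(self, queues):
        self.queues = queues
        self.counter = 0

    def dispatch(self, requests):
        for r in requests:
            self.queues[self.counter].put(r)
            self.counter = (self.counter + 1) % len(self.queues)


class ShortestQueueDispatcher:
    """Send each request to the worker whose queue is currently shortest."""

    def __init__(self, queues):
        self.queues = queues

    def dispatch(self, requests):
        for r in requests:
            # Queue.qsize() is approximate, which is good enough for balancing.
            sizes = [q.qsize() for q in self.queues]
            self.queues[int(np.argmin(sizes))].put(r)


if __name__ == "__main__":
    worker_queues = [multiprocessing.Queue() for _ in range(4)]
    RoundRobinDispatcher(worker_queues).dispatch(["req-0", "req-1", "req-2"])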

sglang/srt/managers/controller/manager_single.py
@@ -1,102 +1,157 @@
 """A controller that manages a group of tensor parallel workers."""
 
-import asyncio
 import logging
-
+import multiprocessing
+import os
+from typing import List
 
-import uvloop
 import zmq
-import zmq.asyncio
 
-from sglang.
-
+from sglang.srt.managers.controller.tp_worker import (
+    ModelTpServer,
+    broadcast_recv_input,
+    launch_tp_servers,
+)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
 logger = logging.getLogger("srt.controller")
 
 
 class ControllerSingle:
-
+    """A controller that manages a group of tensor parallel workers."""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args: dict,
+        gpu_ids: List[int],
+        is_data_parallel_worker: bool,
+        dp_worker_id: int,
+        mp_queue: multiprocessing.Queue,
+    ):
+        # Parse args
+        self.tp_size = server_args.tp_size
+        self.is_dp_worker = is_data_parallel_worker
+        self.dp_worker_id = dp_worker_id
+        self.mp_queue = mp_queue
+
         # Init communication
-        context = zmq.
-
-        self.
+        context = zmq.Context(2)
+
+        if not self.is_dp_worker:
+            self.recv_from_tokenizer = context.socket(zmq.PULL)
+            self.recv_from_tokenizer.bind(
+                f"tcp://127.0.0.1:{port_args.controller_port}"
+            )
 
         self.send_to_detokenizer = context.socket(zmq.PUSH)
         self.send_to_detokenizer.connect(
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )
 
-        #
-
-        self.
-
-
-
+        # Launch other tp ranks
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        self.tp_procs = []
+        if tp_size_local > 1:
+            tp_rank_range = range(1, tp_size_local)
+            self.tp_procs = launch_tp_servers(
+                gpu_ids,
+                tp_rank_range,
+                server_args,
+                port_args.nccl_ports[dp_worker_id],
+                model_overide_args,
+            )
+
+        # Launch tp rank 0
+        self.tp_server = ModelTpServer(
+            gpu_ids[0],
+            0,
+            server_args,
+            port_args.nccl_ports[dp_worker_id],
+            model_overide_args,
+        )
+        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
 
-
+    def loop_for_forward(self):
         while True:
-
-
-
+            if not self.is_dp_worker:
+                recv_reqs = self.recv_requests_from_zmq()
+            else:
+                recv_reqs = self.recv_requests_from_mp_queue()
+
+            if self.tp_size > 1:
+                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
+
+            out_pyobjs = self.tp_server.exposed_step(recv_reqs)
 
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)
 
-
-
-            if len(out_pyobjs) != 0:
-                has_finished = any(
-                    [obj.finished_reason is not None for obj in out_pyobjs]
-                )
-                if has_finished:
-                    if self.request_dependency_delay > 0:
-                        slept = True
-                        await asyncio.sleep(self.request_dependency_delay)
-
-            if not slept:
-                await asyncio.sleep(global_config.wait_for_new_request_delay)
-
-    async def loop_for_recv_requests(self):
+    def recv_requests_from_zmq(self):
+        recv_reqs = []
         while True:
-
-
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+            recv_reqs.append(recv_req)
+
+        return recv_reqs
+
+    def recv_requests_from_mp_queue(self):
+        recv_reqs = []
+        while not self.mp_queue.empty():
+            recv_reqs.append(self.mp_queue.get())
+        return recv_reqs
 
 
 def start_controller_process(
-    server_args: ServerArgs,
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer: multiprocessing.connection.Connection,
+    model_overide_args: dict,
+    is_data_parallel_worker: bool = False,
+    gpu_ids: List[int] = None,
+    dp_worker_id: int = None,
+    queue: multiprocessing.connection.Connection = None,
 ):
+    """Start a controller process."""
+
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )
 
-
+    if not is_data_parallel_worker:
         tp_size_local = server_args.tp_size // server_args.nnodes
-
-
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        dp_worker_id = 0
+        queue = None
+
+    try:
+        controller = ControllerSingle(
             server_args,
-            port_args
+            port_args,
            model_overide_args,
+            gpu_ids,
+            is_data_parallel_worker,
+            dp_worker_id,
+            queue,
         )
-        controller = ControllerSingle(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
 
     pipe_writer.send("init ok")
 
-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
     try:
-
+        controller.loop_for_forward()
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
+        for t in controller.tp_procs:
+            os.kill(t.pid, 9)
         kill_parent_process()
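
Both controllers now drain their ZMQ PULL socket synchronously with zmq.NOBLOCK instead of awaiting messages on an asyncio socket: everything already queued is collected, and the loop returns as soon as the socket reports it is empty. A minimal standalone version of that drain pattern with pyzmq, using a placeholder endpoint, looks like this:

# Sketch of the non-blocking drain pattern used by recv_requests_from_zmq;
# the port and the (absent) sender are placeholders for illustration.
import zmq


def drain(socket):
    """Return every object currently queued on the socket without blocking."""
    items = []
    while True:
        try:
            items.append(socket.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:
            # Raised (as zmq.Again) when no more messages are waiting.
            break
    return items


if __name__ == "__main__":
    context = zmq.Context()
    pull = context.socket(zmq.PULL)
    pull.bind("tcp://127.0.0.1:5555")
    print(drain(pull))  # [] if nothing has been pushed yet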

sglang/srt/managers/controller/model_runner.py
@@ -9,14 +9,24 @@ from typing import Optional, Type
 
 import torch
 import torch.nn as nn
+from flashinfer import (
+    BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
+)
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import
+from vllm.distributed import (
+    get_tp_group,
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -70,6 +80,7 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
@@ -81,10 +92,6 @@ class ModelRunner:
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )
 
-        # Set some global args
-        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
-        global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
-
         # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
@@ -116,6 +123,15 @@ class ModelRunner:
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
+        if (
+            self.server_args.efficient_weight_load
+            and "llama" in self.server_args.model_path.lower()
+            and self.server_args.quantization == "fp8"
+        ):
+            from sglang.srt.model_loader.model_loader import get_model
+        else:
+            from vllm.model_executor.model_loader import get_model
+
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,
@@ -161,7 +177,10 @@ class ModelRunner:
         )
 
         self.req_to_token_pool = ReqToTokenPool(
-
+            max(
+                int(self.max_total_num_tokens / self.model_config.context_len * 512),
+                2048,
+            ),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
@@ -192,13 +211,6 @@ class ModelRunner:
             self.flashinfer_decode_wrapper = None
             return
 
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-            BatchPrefillWithRaggedKVCacheWrapper,
-        )
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         if not _grouped_size_compiled_for_decode_kernels(
             self.model_config.num_attention_heads // self.tp_size,
             self.model_config.get_num_kv_heads(self.tp_size),
@@ -217,7 +229,9 @@ class ModelRunner:
             self.flashinfer_workspace_buffers[1], "NHD"
        )
         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0],
+            self.flashinfer_workspace_buffers[0],
+            "NHD",
+            use_tensor_cores=use_tensor_cores,
         )
 
     def init_cuda_graphs(self):
@@ -228,9 +242,19 @@ class ModelRunner:
             return
 
         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
-        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1,
-        self.cuda_graph_runner = CudaGraphRunner(
-
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+        self.cuda_graph_runner = CudaGraphRunner(
+            self, max_batch_size_to_capture=max(batch_size_list)
+        )
+        try:
+            self.cuda_graph_runner.capture(batch_size_list)
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed {e}. Possible solutions:\n"
+                f"1. disable cuda graph by --disable-cuda-graph\n"
+                f"2. set --mem-fraction-static to a smaller value\n"
+                f"Open an issue on GitHub with reproducible scripts if you need help.\n"
+            )
 
     @torch.inference_mode()
     def forward_decode(self, batch: Batch):
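
init_cuda_graphs now captures a fixed list of batch sizes up front and turns a capture failure into actionable advice (disable CUDA graphs with --disable-cuda-graph or lower --mem-fraction-static). For context, the snippet below is a generic sketch of the PyTorch capture/replay idiom that such a runner builds on; it assumes a CUDA device and a simple callable model, and it is not the sglang CudaGraphRunner.

# Generic capture/replay idiom behind CUDA graph runners (assumes a CUDA GPU).
import torch


def capture(model, batch_size, hidden_size=4096, device="cuda"):
    static_input = torch.zeros(batch_size, hidden_size, device=device)

    # Warm up on a side stream before capture, as PyTorch recommends.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        model(static_input)
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # To run later: copy fresh data into static_input, call graph.replay(),
    # and read the result from static_output.
    return graph, static_input, static_output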

sglang/srt/managers/controller/radix_cache.py
@@ -82,12 +82,12 @@ class RadixCache:
 
         if self.disable:
             if del_in_memory_pool:
-                self.token_to_kv_pool.
+                self.token_to_kv_pool.free(indices)
             else:
                 return torch.tensor([], dtype=torch.int64), self.root_node
 
         # Radix Cache takes one ref in memory pool
-        self.token_to_kv_pool.
+        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
 
         if del_in_memory_pool:
             self.req_to_token_pool.free(req_pool_idx)
@@ -125,7 +125,8 @@ class RadixCache:
             if x.lock_ref > 0:
                 continue
 
-
+            evict_callback(x.value)
+            num_evicted += len(x.value)
             self._delete_leaf(x)
 
             if len(x.parent.children) == 0:
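
In RadixCache, releasing KV cache slots is now spelled token_to_kv_pool.free(...), and evict counts the evicted tokens from the value handed to the callback. A stripped-down version of that eviction loop, with a plain list standing in for the leaf priority queue, could look like the following sketch:

# Simplified eviction loop; a plain list stands in for the queue of evictable
# leaves, and evict_callback would be something like token_to_kv_pool.free.
def evict(leaves, num_tokens, evict_callback):
    num_evicted = 0
    while num_evicted < num_tokens and leaves:
        value = leaves.pop()       # token indices held by one unlocked leaf
        evict_callback(value)      # release the KV cache slots
        num_evicted += len(value)
    return num_evicted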

sglang/srt/managers/controller/schedule_heuristic.py
@@ -13,6 +13,10 @@ class ScheduleHeuristic:
         max_total_num_tokens,
         tree_cache,
     ):
+        if tree_cache.disable and schedule_heuristic == "lpm":
+            # LPM is meaningless when the tree cache is disabled.
+            schedule_heuristic = "fcfs"
+
         self.schedule_heuristic = schedule_heuristic
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens