sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +8 -8
- sglang/api.py +1 -1
- sglang/backend/vertexai.py +5 -4
- sglang/bench.py +627 -0
- sglang/bench_latency.py +22 -19
- sglang/bench_serving.py +976 -0
- sglang/check_env.py +171 -0
- sglang/global_config.py +3 -2
- sglang/lang/backend/__init__.py +0 -0
- sglang/lang/backend/anthropic.py +77 -0
- sglang/lang/backend/base_backend.py +80 -0
- sglang/lang/backend/litellm.py +90 -0
- sglang/lang/backend/openai.py +438 -0
- sglang/lang/backend/runtime_endpoint.py +283 -0
- sglang/lang/backend/vertexai.py +149 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -1
- sglang/launch_server_llavavid.py +1 -4
- sglang/srt/conversation.py +1 -1
- sglang/srt/hf_transformers_utils.py +13 -1
- sglang/srt/layers/context_flashattention_nopad.py +0 -29
- sglang/srt/layers/extend_attention.py +0 -39
- sglang/srt/layers/linear.py +869 -0
- sglang/srt/layers/logits_processor.py +4 -5
- sglang/srt/layers/quantization/__init__.py +49 -0
- sglang/srt/layers/quantization/fp8.py +662 -0
- sglang/srt/layers/radix_attention.py +39 -24
- sglang/srt/layers/token_attention.py +1 -51
- sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
- sglang/srt/managers/controller/infer_batch.py +90 -63
- sglang/srt/managers/controller/manager_multi.py +107 -100
- sglang/srt/managers/controller/manager_single.py +76 -96
- sglang/srt/managers/controller/model_runner.py +41 -26
- sglang/srt/managers/controller/schedule_heuristic.py +8 -3
- sglang/srt/managers/controller/tp_worker.py +136 -149
- sglang/srt/managers/detokenizer_manager.py +49 -5
- sglang/srt/managers/io_struct.py +36 -17
- sglang/srt/managers/tokenizer_manager.py +228 -125
- sglang/srt/memory_pool.py +32 -11
- sglang/srt/model_loader/model_loader.py +277 -0
- sglang/srt/model_loader/utils.py +260 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +430 -0
- sglang/srt/models/gpt_bigcode.py +282 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +317 -0
- sglang/srt/models/llama2.py +81 -23
- sglang/srt/models/llama_classification.py +1 -0
- sglang/srt/models/llava.py +1 -0
- sglang/srt/models/llavavid.py +1 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/openai_api/adapter.py +432 -0
- sglang/srt/openai_api/api_adapter.py +432 -0
- sglang/srt/openai_api/openai_api_adapter.py +431 -0
- sglang/srt/openai_api/openai_protocol.py +207 -0
- sglang/srt/openai_api/protocol.py +208 -0
- sglang/srt/openai_protocol.py +17 -0
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +132 -84
- sglang/srt/server_args.py +35 -21
- sglang/srt/utils.py +65 -117
- sglang/test/test_conversation.py +1 -1
- sglang/test/test_openai_protocol.py +1 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +2 -2
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
- sglang-0.1.24.dist-info/RECORD +105 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
- sglang-0.1.21.dist-info/RECORD +0 -82
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
- {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/manager_multi.py

@@ -3,19 +3,17 @@ A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
 """
 
-import
+import dataclasses
 import logging
-
+import multiprocessing
+import os
 from enum import Enum, auto
-from typing import Dict
 
+import numpy as np
 import zmq
-import zmq.asyncio
 
-from sglang.
-
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
     AbortReq,

@@ -23,12 +21,15 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger("srt.controller")
 
 
 class LoadBalanceMethod(Enum):
+    """Load balance method."""
+
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
 

@@ -41,155 +42,161 @@ class LoadBalanceMethod(Enum):
         raise ValueError(f"Invalid load balance method: {method}") from exc
 
 
-
+@dataclasses.dataclass
+class WorkerHandle:
+    """Store the handle of a data parallel worker."""
+
+    proc: multiprocessing.Process
+    queue: multiprocessing.Queue
+
+
+class ControllerMulti:
     """A controller that manages multiple data parallel workers."""
 
     def __init__(
         self,
-        load_balance_method: str,
         server_args: ServerArgs,
         port_args: PortArgs,
         model_overide_args,
     ):
-
+        # Parse args
         self.server_args = server_args
         self.port_args = port_args
+        self.model_overide_args = model_overide_args
+        self.load_balance_method = LoadBalanceMethod.from_str(
+            server_args.load_balance_method
+        )
 
-
-
+        # Init communication
+        context = zmq.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
 
-
+        # Dispatch method
+        self.round_robin_counter = 0
+        dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
         }
-        self.dispatching =
-
-        # Init communication
-        context = zmq.asyncio.Context()
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-        # Init status
-        self.recv_reqs = []
+        self.dispatching = dispatch_lookup[self.load_balance_method]
 
         # Start data parallel workers
-        self.workers
-        tp_size = server_args.tp_size
-
-        def start_dp_worker(i):
-            try:
-                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
-                worker_thread = start_data_parallel_worker(
-                    server_args, port_args, model_overide_args, gpu_ids, i
-                )
-                self.workers[i] = worker_thread
-            except Exception:
-                logger.error(
-                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
-                )
-
+        self.workers = []
         for i in range(server_args.dp_size):
-            start_dp_worker(i)
+            self.start_dp_worker(i)
+
+    def start_dp_worker(self, dp_worker_id: int):
+        tp_size = self.server_args.tp_size
 
-
-
-
+        pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
+            duplex=False
+        )
 
-
-
+        gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
+        queue = multiprocessing.Queue()
+        proc = multiprocessing.Process(
+            target=start_controller_process_single,
+            args=(
+                self.server_args,
+                self.port_args,
+                pipe_controller_writer,
+                self.model_overide_args,
+                True,
+                gpu_ids,
+                dp_worker_id,
+                queue,
+            ),
+        )
+        proc.start()
 
-
-
+        controller_init_state = pipe_controller_reader.recv()
+        if controller_init_state != "init ok":
+            raise RuntimeError(
+                f"Initialization failed. controller_init_state: {controller_init_state}"
+            )
+        self.workers.append(
+            WorkerHandle(
+                proc=proc,
+                queue=queue,
+            )
+        )
 
-
-        available_workers = list(self.workers.keys())
+    def round_robin_scheduler(self, input_requests):
         for r in input_requests:
-            self.
+            self.workers[self.round_robin_counter].queue.put(r)
             self.round_robin_counter = (self.round_robin_counter + 1) % len(
-
+                self.workers
             )
-        return
 
-
+    def shortest_queue_scheduler(self, input_requests):
         for r in input_requests:
-
-
-            )
-            self.put_req_to_worker(worker, r)
-        return
-
-    async def remove_dead_workers(self):
-        for i in list(self.workers.keys()):
-            worker_thread = self.workers[i]
-            if not worker_thread.liveness:
-                worker_thread.join()
-                # move unsuccessful requests back to the queue
-                while not worker_thread.request_queue.empty():
-                    self.recv_reqs.append(worker_thread.request_queue.get())
-                del self.workers[i]
-                logger.info(f"Stale worker {i} removed")
-
-    async def loop_for_forward(self):
-        while True:
-            await self.remove_dead_workers()
+            queue_sizes = [worker.queue.qsize() for worker in self.workers]
+            wid = np.argmin(queue_sizes)
+            self.workers[wid].queue.put(r)
 
-
-
-
-
-                await self.dispatching(next_step_input)
-            # else:
-            #     logger.error("There is no live worker.")
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+            self.dispatching(recv_reqs)
 
-
+    def recv_requests(self):
+        recv_reqs = []
 
-    async def loop_for_recv_requests(self):
         while True:
-
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
             if isinstance(recv_req, FlushCacheReq):
                 # TODO(lsyin): apply more specific flushCacheReq
-                for
-
-            elif isinstance(recv_req, TokenizedGenerateReqInput):
-                self.recv_reqs.append(recv_req)
+                for worker in self.workers:
+                    worker.queue.put(recv_req)
             elif isinstance(recv_req, AbortReq):
                 in_queue = False
-                for i, req in enumerate(
+                for i, req in enumerate(recv_reqs):
                     if req.rid == recv_req.rid:
-
+                        recv_reqs[i] = recv_req
                         in_queue = True
                         break
                 if not in_queue:
                     # Send abort req to all TP groups
-                    for worker in
-
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                recv_reqs.append(recv_req)
             else:
                 logger.error(f"Invalid object: {recv_req}")
 
+        return recv_reqs
+
 
 def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args
+    model_overide_args: dict,
 ):
+    """Start a controller process."""
+
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )
 
     try:
-        controller =
-            server_args.load_balance_method, server_args, port_args, model_overide_args
-        )
+        controller = ControllerMulti(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
-    pipe_writer.send("init ok")
 
-
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
+    pipe_writer.send("init ok")
 
-
-
-
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
+    finally:
+        for w in controller.workers:
+            os.kill(w.proc.pid, 9)
+        kill_parent_process()
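For orientation, the rewrite above replaces the old per-worker threads with one `multiprocessing.Process` plus `multiprocessing.Queue` per data parallel worker (`WorkerHandle`), and `ControllerMulti` pushes requests either round-robin or to the worker with the shortest queue (`np.argmin` over `qsize()`). The standalone sketch below imitates only that dispatch pattern; the worker function, queue sentinel, and request strings are made up for illustration and are not part of sglang.

```python
import multiprocessing as mp

import numpy as np


def worker_loop(worker_id: int, queue) -> None:
    # Drain requests until a None sentinel arrives.
    while True:
        req = queue.get()
        if req is None:
            break
        print(f"worker {worker_id} handling {req}")


def round_robin_dispatch(queues, requests):
    # Mirrors the ROUND_ROBIN policy: rotate through the worker queues.
    counter = 0
    for req in requests:
        queues[counter].put(req)
        counter = (counter + 1) % len(queues)


def shortest_queue_dispatch(queues, requests):
    # Mirrors the SHORTEST_QUEUE policy: pick the queue with the fewest
    # pending items. Queue.qsize() is approximate and not available on
    # every platform (e.g. macOS), same caveat as the real code path.
    for req in requests:
        wid = int(np.argmin([q.qsize() for q in queues]))
        queues[wid].put(req)


if __name__ == "__main__":
    num_workers = 2
    queues = [mp.Queue() for _ in range(num_workers)]
    procs = [
        mp.Process(target=worker_loop, args=(i, queues[i]))
        for i in range(num_workers)
    ]
    for p in procs:
        p.start()

    round_robin_dispatch(queues, [f"req-{i}" for i in range(4)])
    shortest_queue_dispatch(queues, [f"req-{i}" for i in range(4, 8)])

    for q in queues:
        q.put(None)  # stop sentinel
    for p in procs:
        p.join()
```

The queue-per-worker design also explains the abort path in `recv_requests`: if an abort cannot be matched against a request still waiting in the controller, it is broadcast to every worker queue.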
sglang/srt/managers/controller/manager_single.py

@@ -1,138 +1,88 @@
 """A controller that manages a group of tensor parallel workers."""
 
-import multiprocessing
 import logging
+import multiprocessing
 import os
-import
+from typing import List
 
-import torch
-import torch.distributed as dist
 import zmq
-import zmq.asyncio
 
-from sglang.srt.managers.controller.tp_worker import
-
+from sglang.srt.managers.controller.tp_worker import (
+    ModelTpServer,
+    broadcast_recv_input,
+    launch_tp_servers,
+)
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback
 
-
 logger = logging.getLogger("srt.controller")
 
 
-def run_tp_server(
-    gpu_id: int,
-    tp_rank: int,
-    server_args: ServerArgs,
-    model_port_args: ModelPortArgs,
-    model_overide_args: dict,
-):
-    """Run a tp server."""
-    try:
-        model_server = ModelTpServer(
-            gpu_id,
-            tp_rank,
-            server_args,
-            model_port_args,
-            model_overide_args,
-        )
-        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
-
-        while True:
-            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
-            model_server.exposed_step(recv_reqs)
-    except Exception:
-        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
-        raise
-
-
-def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
-                      model_port_args, model_overide_args):
-    """Launch multiple tp servers."""
-    procs = []
-    for i in tp_rank_range:
-        proc = multiprocessing.Process(target=run_tp_server, args=(
-            gpu_ids[i], i, server_args, model_port_args, model_overide_args
-        ))
-        proc.start()
-        procs.append(proc)
-
-    return procs
-
-
-def broadcast_recv_input(data, rank, dist_group):
-    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
-
-    if rank == 0:
-        if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-        else:
-            serialized_data = pickle.dumps(data)
-            size = len(serialized_data)
-            tensor_data = torch.ByteTensor(list(serialized_data))
-            tensor_size = torch.tensor([size], dtype=torch.long)
-
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-            dist.broadcast(tensor_data, src=0, group=dist_group)
-    else:
-        tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=0, group=dist_group)
-        size = tensor_size.item()
-
-        if size == 0:
-            return []
-
-        tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=0, group=dist_group)
-
-        serialized_data = bytes(tensor_data.tolist())
-        data = pickle.loads(serialized_data)
-        return data
-
-
 class ControllerSingle:
     """A controller that manages a group of tensor parallel workers."""
 
-    def __init__(
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args: dict,
+        gpu_ids: List[int],
+        is_data_parallel_worker: bool,
+        dp_worker_id: int,
+        mp_queue: multiprocessing.Queue,
+    ):
         # Parse args
-        self.
+        self.tp_size = server_args.tp_size
+        self.is_dp_worker = is_data_parallel_worker
+        self.dp_worker_id = dp_worker_id
+        self.mp_queue = mp_queue
 
         # Init communication
         context = zmq.Context(2)
-
-        self.
+
+        if not self.is_dp_worker:
+            self.recv_from_tokenizer = context.socket(zmq.PULL)
+            self.recv_from_tokenizer.bind(
+                f"tcp://127.0.0.1:{port_args.controller_port}"
+            )
 
         self.send_to_detokenizer = context.socket(zmq.PUSH)
         self.send_to_detokenizer.connect(
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )
 
-        # Init model server
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-
         # Launch other tp ranks
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        self.tp_procs = []
         if tp_size_local > 1:
             tp_rank_range = range(1, tp_size_local)
             self.tp_procs = launch_tp_servers(
-                gpu_ids,
-
+                gpu_ids,
+                tp_rank_range,
+                server_args,
+                port_args.nccl_ports[dp_worker_id],
+                model_overide_args,
+            )
 
         # Launch tp rank 0
         self.tp_server = ModelTpServer(
             gpu_ids[0],
             0,
             server_args,
-            port_args.
+            port_args.nccl_ports[dp_worker_id],
             model_overide_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
 
     def loop_for_forward(self):
         while True:
-
+            if not self.is_dp_worker:
+                recv_reqs = self.recv_requests_from_zmq()
+            else:
+                recv_reqs = self.recv_requests_from_mp_queue()
 
-            if self.
+            if self.tp_size > 1:
                 broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
 
             out_pyobjs = self.tp_server.exposed_step(recv_reqs)

@@ -140,27 +90,57 @@ class ControllerSingle:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)
 
-    def
+    def recv_requests_from_zmq(self):
         recv_reqs = []
         while True:
             try:
                 recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                recv_reqs.append(recv_req)
             except zmq.ZMQError:
                 break
+            recv_reqs.append(recv_req)
+
+        return recv_reqs
+
+    def recv_requests_from_mp_queue(self):
+        recv_reqs = []
+        while not self.mp_queue.empty():
+            recv_reqs.append(self.mp_queue.get())
         return recv_reqs
 
 
 def start_controller_process(
-    server_args: ServerArgs,
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer: multiprocessing.connection.Connection,
+    model_overide_args: dict,
+    is_data_parallel_worker: bool = False,
+    gpu_ids: List[int] = None,
+    dp_worker_id: int = None,
+    queue: multiprocessing.connection.Connection = None,
 ):
+    """Start a controller process."""
+
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )
 
+    if not is_data_parallel_worker:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        dp_worker_id = 0
+        queue = None
+
     try:
-        controller = ControllerSingle(
+        controller = ControllerSingle(
+            server_args,
+            port_args,
+            model_overide_args,
+            gpu_ids,
+            is_data_parallel_worker,
+            dp_worker_id,
+            queue,
+        )
    except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
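In the new `ControllerSingle`, request intake depends on how the controller was launched: standalone, it drains a ZMQ PULL socket with non-blocking `recv_pyobj(zmq.NOBLOCK)`; as a data parallel worker, it drains the `multiprocessing.Queue` that `ControllerMulti` fills. A minimal sketch of the two drain loops follows, with a toy in-process demo; the socket address and payload dicts are illustrative only.

```python
import multiprocessing
import time

import zmq


def drain_zmq_socket(socket: zmq.Socket) -> list:
    """Collect everything currently waiting on a PULL socket without blocking."""
    reqs = []
    while True:
        try:
            reqs.append(socket.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:  # raised when no message is ready
            break
    return reqs


def drain_mp_queue(queue) -> list:
    """Collect everything currently waiting on a multiprocessing queue."""
    reqs = []
    while not queue.empty():
        reqs.append(queue.get())
    return reqs


if __name__ == "__main__":
    ctx = zmq.Context()
    pull = ctx.socket(zmq.PULL)
    pull.bind("inproc://demo")
    push = ctx.socket(zmq.PUSH)
    push.connect("inproc://demo")
    push.send_pyobj({"rid": "toy-request"})
    time.sleep(0.1)  # give the message time to arrive
    print(drain_zmq_socket(pull))  # [{'rid': 'toy-request'}]

    q = multiprocessing.Queue()
    q.put({"rid": "another-toy-request"})
    time.sleep(0.1)  # Queue.put() hands off to a feeder thread; let it run
    print(drain_mp_queue(q))  # [{'rid': 'another-toy-request'}]
```

Either way, the drained batch is then broadcast to the other tensor parallel ranks (when `tp_size > 1`) before `exposed_step` runs, which is why the per-iteration receive is batched rather than one message at a time.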
sglang/srt/managers/controller/model_runner.py

@@ -9,19 +9,23 @@ from typing import Optional, Type
 
 import torch
 import torch.nn as nn
+from flashinfer import (
+    BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
+)
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import
-
+from vllm.distributed import (
+    get_tp_group,
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (

@@ -87,12 +91,6 @@ class ModelRunner:
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )
 
-        # Set some global args
-        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
-        global_server_args_dict[
-            "attention_reduce_in_fp32"
-        ] = server_args.attention_reduce_in_fp32
-
         # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)

@@ -124,6 +122,15 @@ class ModelRunner:
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
+        if (
+            self.server_args.efficient_weight_load
+            and "llama" in self.server_args.model_path.lower()
+            and self.server_args.quantization == "fp8"
+        ):
+            from sglang.srt.model_loader.model_loader import get_model
+        else:
+            from vllm.model_executor.model_loader import get_model
+
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,

@@ -169,7 +176,10 @@ class ModelRunner:
         )
 
         self.req_to_token_pool = ReqToTokenPool(
-
+            max(
+                int(self.max_total_num_tokens / self.model_config.context_len * 512),
+                2048,
+            ),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(

@@ -200,13 +210,6 @@ class ModelRunner:
             self.flashinfer_decode_wrapper = None
             return
 
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-            BatchPrefillWithRaggedKVCacheWrapper,
-        )
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         if not _grouped_size_compiled_for_decode_kernels(
             self.model_config.num_attention_heads // self.tp_size,
             self.model_config.get_num_kv_heads(self.tp_size),

@@ -237,12 +240,24 @@ class ModelRunner:
             self.cuda_graph_runner = None
             return
 
-        logger.info(
-
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
+        )
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
         self.cuda_graph_runner = CudaGraphRunner(
-            self,
+            self,
+            max_batch_size_to_capture=max(batch_size_list),
+            use_torch_compile=self.server_args.enable_torch_compile,
         )
-
+        try:
+            self.cuda_graph_runner.capture(batch_size_list)
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}. Possible solutions:\n"
+                f"1. disable cuda graph by --disable-cuda-graph\n"
+                f"2. set --mem-fraction-static to a smaller value\n"
+                f"Open an issue on GitHub with reproducible scripts if you need help.\n"
+            )
 
     @torch.inference_mode()
     def forward_decode(self, batch: Batch):
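Two of the changes above are easy to sanity-check numerically: the request-to-token pool is now sized as `max(int(max_total_num_tokens / context_len * 512), 2048)` slots, and CUDA graphs are captured for a fixed list of batch sizes. The snippet below only reproduces that arithmetic; the token budgets are made-up values, not measured ones.

```python
def req_to_token_pool_size(max_total_num_tokens: int, context_len: int) -> int:
    # Same formula as the new ReqToTokenPool sizing: scale the number of
    # request slots with the token budget, but never drop below 2048 slots.
    return max(int(max_total_num_tokens / context_len * 512), 2048)


# Hypothetical KV-cache budgets, for illustration only.
print(req_to_token_pool_size(400_000, 4_096))    # 50000 (the formula dominates)
print(req_to_token_pool_size(400_000, 128_000))  # 2048  (1600 is clamped up to the floor)

# Batch sizes captured by the CUDA graph runner.
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
print(batch_size_list)       # [1, 2, 4, 8, 16, ..., 128]
print(max(batch_size_list))  # 128 -> passed as max_batch_size_to_capture
```

The explicit `try/except` around `capture()` also turns a CUDA graph failure into an actionable error message (disable CUDA graphs or lower `--mem-fraction-static`) instead of an opaque crash.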
sglang/srt/managers/controller/schedule_heuristic.py

@@ -14,7 +14,7 @@ class ScheduleHeuristic:
         tree_cache,
     ):
         if tree_cache.disable and schedule_heuristic == "lpm":
-            # LMP is
+            # LMP is meaningless when the tree cache is disabled.
             schedule_heuristic = "fcfs"
 
         self.schedule_heuristic = schedule_heuristic

@@ -28,11 +28,16 @@ class ScheduleHeuristic:
             # longest prefix match
             forward_queue.sort(key=lambda x: -len(x.prefix_indices))
             return forward_queue
+        elif self.schedule_heuristic == "fcfs":
+            # first come first serve
+            return forward_queue
+        elif self.schedule_heuristic == "lof":
+            # longest output first
+            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return forward_queue
         elif self.schedule_heuristic == "random":
             random.shuffle(forward_queue)
             return forward_queue
-        elif self.schedule_heuristic == "fcfs":
-            return forward_queue
         elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
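The scheduler now recognizes `lof` (longest output first) alongside `lpm`, `fcfs`, `random`, and `dfs-weight`. The sort keys are simple enough to demonstrate on toy requests; the `Req` stand-in below is illustrative and far smaller than sglang's real request object (which reads `sampling_params.max_new_tokens` rather than a flat field).

```python
import random
from dataclasses import dataclass, field


@dataclass
class Req:
    rid: str
    prefix_indices: list = field(default_factory=list)  # tokens reused from the radix cache
    max_new_tokens: int = 0


def schedule(queue, heuristic):
    if heuristic == "lpm":       # longest prefix match: most cached tokens first
        return sorted(queue, key=lambda r: -len(r.prefix_indices))
    if heuristic == "fcfs":      # first come first serve: keep arrival order
        return list(queue)
    if heuristic == "lof":       # longest output first: largest max_new_tokens first
        return sorted(queue, key=lambda r: -r.max_new_tokens)
    if heuristic == "random":
        shuffled = list(queue)
        random.shuffle(shuffled)
        return shuffled
    raise ValueError(f"unknown heuristic: {heuristic}")


reqs = [
    Req("a", prefix_indices=[0] * 10, max_new_tokens=16),
    Req("b", prefix_indices=[0] * 2, max_new_tokens=512),
    Req("c", prefix_indices=[0] * 30, max_new_tokens=64),
]
print([r.rid for r in schedule(reqs, "lpm")])  # ['c', 'a', 'b']
print([r.rid for r in schedule(reqs, "lof")])  # ['b', 'c', 'a']
```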