sglang 0.1.21__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/vertexai.py +5 -4
  4. sglang/bench.py +627 -0
  5. sglang/bench_latency.py +22 -19
  6. sglang/bench_serving.py +976 -0
  7. sglang/check_env.py +171 -0
  8. sglang/global_config.py +3 -2
  9. sglang/lang/backend/__init__.py +0 -0
  10. sglang/lang/backend/anthropic.py +77 -0
  11. sglang/lang/backend/base_backend.py +80 -0
  12. sglang/lang/backend/litellm.py +90 -0
  13. sglang/lang/backend/openai.py +438 -0
  14. sglang/lang/backend/runtime_endpoint.py +283 -0
  15. sglang/lang/backend/vertexai.py +149 -0
  16. sglang/lang/interpreter.py +1 -0
  17. sglang/lang/tracer.py +1 -1
  18. sglang/launch_server.py +1 -1
  19. sglang/launch_server_llavavid.py +1 -4
  20. sglang/srt/conversation.py +1 -1
  21. sglang/srt/hf_transformers_utils.py +13 -1
  22. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  23. sglang/srt/layers/extend_attention.py +0 -39
  24. sglang/srt/layers/linear.py +869 -0
  25. sglang/srt/layers/logits_processor.py +4 -5
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +39 -24
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +72 -28
  31. sglang/srt/managers/controller/infer_batch.py +90 -63
  32. sglang/srt/managers/controller/manager_multi.py +107 -100
  33. sglang/srt/managers/controller/manager_single.py +76 -96
  34. sglang/srt/managers/controller/model_runner.py +41 -26
  35. sglang/srt/managers/controller/schedule_heuristic.py +8 -3
  36. sglang/srt/managers/controller/tp_worker.py +136 -149
  37. sglang/srt/managers/detokenizer_manager.py +49 -5
  38. sglang/srt/managers/io_struct.py +36 -17
  39. sglang/srt/managers/tokenizer_manager.py +228 -125
  40. sglang/srt/memory_pool.py +32 -11
  41. sglang/srt/model_loader/model_loader.py +277 -0
  42. sglang/srt/model_loader/utils.py +260 -0
  43. sglang/srt/models/chatglm.py +1 -0
  44. sglang/srt/models/dbrx.py +1 -0
  45. sglang/srt/models/deepseek.py +430 -0
  46. sglang/srt/models/gpt_bigcode.py +282 -0
  47. sglang/srt/models/grok.py +1 -0
  48. sglang/srt/models/internlm2.py +317 -0
  49. sglang/srt/models/llama2.py +81 -23
  50. sglang/srt/models/llama_classification.py +1 -0
  51. sglang/srt/models/llava.py +1 -0
  52. sglang/srt/models/llavavid.py +1 -0
  53. sglang/srt/models/minicpm.py +1 -0
  54. sglang/srt/models/mixtral.py +1 -0
  55. sglang/srt/models/mixtral_quant.py +1 -0
  56. sglang/srt/models/qwen.py +1 -0
  57. sglang/srt/models/qwen2.py +6 -0
  58. sglang/srt/models/qwen2_moe.py +7 -4
  59. sglang/srt/models/stablelm.py +1 -0
  60. sglang/srt/openai_api/adapter.py +432 -0
  61. sglang/srt/openai_api/api_adapter.py +432 -0
  62. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  63. sglang/srt/openai_api/openai_protocol.py +207 -0
  64. sglang/srt/openai_api/protocol.py +208 -0
  65. sglang/srt/openai_protocol.py +17 -0
  66. sglang/srt/sampling_params.py +2 -0
  67. sglang/srt/server.py +132 -84
  68. sglang/srt/server_args.py +35 -21
  69. sglang/srt/utils.py +65 -117
  70. sglang/test/test_conversation.py +1 -1
  71. sglang/test/test_openai_protocol.py +1 -1
  72. sglang/test/test_programs.py +1 -1
  73. sglang/test/test_utils.py +2 -2
  74. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/METADATA +162 -168
  75. sglang-0.1.24.dist-info/RECORD +105 -0
  76. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/WHEEL +1 -1
  77. sglang-0.1.21.dist-info/RECORD +0 -82
  78. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/LICENSE +0 -0
  79. {sglang-0.1.21.dist-info → sglang-0.1.24.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/manager_multi.py
@@ -3,19 +3,17 @@ A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
 """

-import asyncio
+import dataclasses
 import logging
-from concurrent.futures import ThreadPoolExecutor
+import multiprocessing
+import os
 from enum import Enum, auto
-from typing import Dict

+import numpy as np
 import zmq
-import zmq.asyncio

-from sglang.global_config import global_config
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -23,12 +21,15 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")


 class LoadBalanceMethod(Enum):
+    """Load balance method."""
+
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()

@@ -41,155 +42,161 @@ class LoadBalanceMethod(Enum):
             raise ValueError(f"Invalid load balance method: {method}") from exc


-class Controller:
+@dataclasses.dataclass
+class WorkerHandle:
+    """Store the handle of a data parallel worker."""
+
+    proc: multiprocessing.Process
+    queue: multiprocessing.Queue
+
+
+class ControllerMulti:
     """A controller that manages multiple data parallel workers."""

     def __init__(
         self,
-        load_balance_method: str,
         server_args: ServerArgs,
         port_args: PortArgs,
         model_overide_args,
     ):
-        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+        # Parse args
         self.server_args = server_args
         self.port_args = port_args
+        self.model_overide_args = model_overide_args
+        self.load_balance_method = LoadBalanceMethod.from_str(
+            server_args.load_balance_method
+        )

-        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
-            self.round_robin_counter = 0
+        # Init communication
+        context = zmq.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")

-        self.dispatch_lookup = {
+        # Dispatch method
+        self.round_robin_counter = 0
+        dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
         }
-        self.dispatching = self.dispatch_lookup[self.load_balance_method]
-
-        # Init communication
-        context = zmq.asyncio.Context()
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-        # Init status
-        self.recv_reqs = []
+        self.dispatching = dispatch_lookup[self.load_balance_method]

         # Start data parallel workers
-        self.workers: Dict[int, DataParallelWorkerThread] = {}
-        tp_size = server_args.tp_size
-
-        def start_dp_worker(i):
-            try:
-                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
-                worker_thread = start_data_parallel_worker(
-                    server_args, port_args, model_overide_args, gpu_ids, i
-                )
-                self.workers[i] = worker_thread
-            except Exception:
-                logger.error(
-                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
-                )
-
+        self.workers = []
         for i in range(server_args.dp_size):
-            start_dp_worker(i)
+            self.start_dp_worker(i)
+
+    def start_dp_worker(self, dp_worker_id: int):
+        tp_size = self.server_args.tp_size

-        # Parallel launch is slower, probably due to the disk bandwidth limitations.
-        # with ThreadPoolExecutor(server_args.dp_size) as executor:
-        #     executor.map(start_dp_worker, range(server_args.dp_size))
+        pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
+            duplex=False
+        )

-    def have_any_live_worker(self):
-        return any(worker_thread.liveness for worker_thread in self.workers.values())
+        gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
+        queue = multiprocessing.Queue()
+        proc = multiprocessing.Process(
+            target=start_controller_process_single,
+            args=(
+                self.server_args,
+                self.port_args,
+                pipe_controller_writer,
+                self.model_overide_args,
+                True,
+                gpu_ids,
+                dp_worker_id,
+                queue,
+            ),
+        )
+        proc.start()

-    def put_req_to_worker(self, worker_id, req):
-        self.workers[worker_id].request_queue.put(req)
+        controller_init_state = pipe_controller_reader.recv()
+        if controller_init_state != "init ok":
+            raise RuntimeError(
+                f"Initialization failed. controller_init_state: {controller_init_state}"
+            )
+        self.workers.append(
+            WorkerHandle(
+                proc=proc,
+                queue=queue,
+            )
+        )

-    async def round_robin_scheduler(self, input_requests):
-        available_workers = list(self.workers.keys())
+    def round_robin_scheduler(self, input_requests):
         for r in input_requests:
-            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+            self.workers[self.round_robin_counter].queue.put(r)
             self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                available_workers
+                self.workers
             )
-        return

-    async def shortest_queue_scheduler(self, input_requests):
+    def shortest_queue_scheduler(self, input_requests):
         for r in input_requests:
-            worker = min(
-                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
-            )
-            self.put_req_to_worker(worker, r)
-        return
-
-    async def remove_dead_workers(self):
-        for i in list(self.workers.keys()):
-            worker_thread = self.workers[i]
-            if not worker_thread.liveness:
-                worker_thread.join()
-                # move unsuccessful requests back to the queue
-                while not worker_thread.request_queue.empty():
-                    self.recv_reqs.append(worker_thread.request_queue.get())
-                del self.workers[i]
-                logger.info(f"Stale worker {i} removed")
-
-    async def loop_for_forward(self):
-        while True:
-            await self.remove_dead_workers()
+            queue_sizes = [worker.queue.qsize() for worker in self.workers]
+            wid = np.argmin(queue_sizes)
+            self.workers[wid].queue.put(r)

-            if self.have_any_live_worker():
-                next_step_input = list(self.recv_reqs)
-                self.recv_reqs = []
-                if next_step_input:
-                    await self.dispatching(next_step_input)
-            # else:
-            #     logger.error("There is no live worker.")
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+            self.dispatching(recv_reqs)

-            await asyncio.sleep(global_config.wait_for_new_request_delay)
+    def recv_requests(self):
+        recv_reqs = []

-    async def loop_for_recv_requests(self):
         while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
             if isinstance(recv_req, FlushCacheReq):
                 # TODO(lsyin): apply more specific flushCacheReq
-                for worker_thread in self.workers.values():
-                    worker_thread.request_queue.put(recv_req)
-            elif isinstance(recv_req, TokenizedGenerateReqInput):
-                self.recv_reqs.append(recv_req)
+                for worker in self.workers:
+                    worker.queue.put(recv_req)
             elif isinstance(recv_req, AbortReq):
                 in_queue = False
-                for i, req in enumerate(self.recv_reqs):
+                for i, req in enumerate(recv_reqs):
                     if req.rid == recv_req.rid:
-                        self.recv_reqs[i] = recv_req
+                        recv_reqs[i] = recv_req
                         in_queue = True
                         break
                 if not in_queue:
                     # Send abort req to all TP groups
-                    for worker in list(self.workers.keys()):
-                        self.put_req_to_worker(worker, recv_req)
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                recv_reqs.append(recv_req)
             else:
                 logger.error(f"Invalid object: {recv_req}")

+        return recv_reqs
+

 def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args=None,
+    model_overide_args: dict,
 ):
+    """Start a controller process."""
+
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

     try:
-        controller = Controller(
-            server_args.load_balance_method, server_args, port_args, model_overide_args
-        )
+        controller = ControllerMulti(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
-    pipe_writer.send("init ok")

-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
+    pipe_writer.send("init ok")

-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
-    loop.run_until_complete(controller.loop_for_forward())
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
+    finally:
+        for w in controller.workers:
+            os.kill(w.proc.pid, 9)
+        kill_parent_process()
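For readers skimming the manager_multi.py hunks above: the multi-worker controller drops the asyncio event loop and worker threads and instead launches each data parallel worker as a process with its own request queue, dispatching either round-robin or to the shortest queue. The following is a minimal, self-contained sketch of that dispatch pattern only; ToyDispatcher and its methods are illustrative and not part of the sglang API.

# Minimal sketch (not the sglang API): dispatching requests to data parallel
# workers through per-worker multiprocessing queues, as the new ControllerMulti does.
import multiprocessing

import numpy as np


class ToyDispatcher:
    def __init__(self, num_workers: int):
        # One queue per worker process; the real code also keeps the Process handle.
        self.queues = [multiprocessing.Queue() for _ in range(num_workers)]
        self.round_robin_counter = 0

    def round_robin(self, requests):
        for r in requests:
            self.queues[self.round_robin_counter].put(r)
            self.round_robin_counter = (self.round_robin_counter + 1) % len(self.queues)

    def shortest_queue(self, requests):
        for r in requests:
            # qsize() is approximate (and unavailable on some platforms, e.g. macOS),
            # which is acceptable for coarse load balancing.
            sizes = [q.qsize() for q in self.queues]
            self.queues[int(np.argmin(sizes))].put(r)


if __name__ == "__main__":
    d = ToyDispatcher(num_workers=2)
    d.round_robin(["req-0", "req-1", "req-2"])  # req-0 and req-2 land on worker 0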
sglang/srt/managers/controller/manager_single.py
@@ -1,138 +1,88 @@
 """A controller that manages a group of tensor parallel workers."""

-import multiprocessing
 import logging
+import multiprocessing
 import os
-import pickle
+from typing import List

-import torch
-import torch.distributed as dist
 import zmq
-import zmq.asyncio

-from sglang.srt.managers.controller.tp_worker import ModelTpServer
-from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
+from sglang.srt.managers.controller.tp_worker import (
+    ModelTpServer,
+    broadcast_recv_input,
+    launch_tp_servers,
+)
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

-
 logger = logging.getLogger("srt.controller")


-def run_tp_server(
-    gpu_id: int,
-    tp_rank: int,
-    server_args: ServerArgs,
-    model_port_args: ModelPortArgs,
-    model_overide_args: dict,
-):
-    """Run a tp server."""
-    try:
-        model_server = ModelTpServer(
-            gpu_id,
-            tp_rank,
-            server_args,
-            model_port_args,
-            model_overide_args,
-        )
-        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
-
-        while True:
-            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
-            model_server.exposed_step(recv_reqs)
-    except Exception:
-        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
-        raise
-
-
-def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
-                      model_port_args, model_overide_args):
-    """Launch multiple tp servers."""
-    procs = []
-    for i in tp_rank_range:
-        proc = multiprocessing.Process(target=run_tp_server, args=(
-            gpu_ids[i], i, server_args, model_port_args, model_overide_args
-        ))
-        proc.start()
-        procs.append(proc)
-
-    return procs
-
-
-def broadcast_recv_input(data, rank, dist_group):
-    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
-
-    if rank == 0:
-        if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-        else:
-            serialized_data = pickle.dumps(data)
-            size = len(serialized_data)
-            tensor_data = torch.ByteTensor(list(serialized_data))
-            tensor_size = torch.tensor([size], dtype=torch.long)
-
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-            dist.broadcast(tensor_data, src=0, group=dist_group)
-    else:
-        tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=0, group=dist_group)
-        size = tensor_size.item()
-
-        if size == 0:
-            return []
-
-        tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=0, group=dist_group)
-
-        serialized_data = bytes(tensor_data.tolist())
-        data = pickle.loads(serialized_data)
-        return data
-
-
 class ControllerSingle:
     """A controller that manages a group of tensor parallel workers."""

-    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args: dict,
+        gpu_ids: List[int],
+        is_data_parallel_worker: bool,
+        dp_worker_id: int,
+        mp_queue: multiprocessing.Queue,
+    ):
         # Parse args
-        self.server_args = server_args
+        self.tp_size = server_args.tp_size
+        self.is_dp_worker = is_data_parallel_worker
+        self.dp_worker_id = dp_worker_id
+        self.mp_queue = mp_queue

         # Init communication
         context = zmq.Context(2)
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        if not self.is_dp_worker:
+            self.recv_from_tokenizer = context.socket(zmq.PULL)
+            self.recv_from_tokenizer.bind(
+                f"tcp://127.0.0.1:{port_args.controller_port}"
+            )

         self.send_to_detokenizer = context.socket(zmq.PUSH)
         self.send_to_detokenizer.connect(
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )

-        # Init model server
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-
         # Launch other tp ranks
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        self.tp_procs = []
         if tp_size_local > 1:
             tp_rank_range = range(1, tp_size_local)
             self.tp_procs = launch_tp_servers(
-                gpu_ids, tp_rank_range, server_args,
-                port_args.model_port_args[0], model_overide_args)
+                gpu_ids,
+                tp_rank_range,
+                server_args,
+                port_args.nccl_ports[dp_worker_id],
+                model_overide_args,
+            )

         # Launch tp rank 0
         self.tp_server = ModelTpServer(
             gpu_ids[0],
             0,
             server_args,
-            port_args.model_port_args[0],
+            port_args.nccl_ports[dp_worker_id],
             model_overide_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

     def loop_for_forward(self):
         while True:
-            recv_reqs = self.recv_requests()
+            if not self.is_dp_worker:
+                recv_reqs = self.recv_requests_from_zmq()
+            else:
+                recv_reqs = self.recv_requests_from_mp_queue()

-            if self.server_args.tp_size > 1:
+            if self.tp_size > 1:
                 broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)

             out_pyobjs = self.tp_server.exposed_step(recv_reqs)
@@ -140,27 +90,57 @@ class ControllerSingle:
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-    def recv_requests(self):
+    def recv_requests_from_zmq(self):
         recv_reqs = []
         while True:
             try:
                 recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-                recv_reqs.append(recv_req)
             except zmq.ZMQError:
                 break
+            recv_reqs.append(recv_req)
+
+        return recv_reqs
+
+    def recv_requests_from_mp_queue(self):
+        recv_reqs = []
+        while not self.mp_queue.empty():
+            recv_reqs.append(self.mp_queue.get())
         return recv_reqs


 def start_controller_process(
-    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer: multiprocessing.connection.Connection,
+    model_overide_args: dict,
+    is_data_parallel_worker: bool = False,
+    gpu_ids: List[int] = None,
+    dp_worker_id: int = None,
+    queue: multiprocessing.connection.Connection = None,
 ):
+    """Start a controller process."""
+
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

+    if not is_data_parallel_worker:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        dp_worker_id = 0
+        queue = None
+
     try:
-        controller = ControllerSingle(server_args, port_args, model_overide_args)
+        controller = ControllerSingle(
+            server_args,
+            port_args,
+            model_overide_args,
+            gpu_ids,
+            is_data_parallel_worker,
+            dp_worker_id,
+            queue,
+        )
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
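The manager_single.py change above splits request intake into two draining paths: a non-blocking zmq PULL socket when the controller runs standalone, and a multiprocessing.Queue when it runs as a data parallel worker under ControllerMulti. Below is a minimal sketch of just those two receive loops; the helper names and the port number are illustrative, not sglang APIs.

# Minimal sketch (not the sglang API): drain a zmq PULL socket without blocking,
# or drain a multiprocessing.Queue, returning whatever is available right now.
import multiprocessing

import zmq


def drain_zmq_socket(sock: "zmq.Socket"):
    reqs = []
    while True:
        try:
            reqs.append(sock.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:
            break  # nothing left to read at the moment
    return reqs


def drain_mp_queue(q: "multiprocessing.Queue"):
    reqs = []
    while not q.empty():
        reqs.append(q.get())
    return reqs


if __name__ == "__main__":
    ctx = zmq.Context()
    pull = ctx.socket(zmq.PULL)
    pull.bind("tcp://127.0.0.1:5555")  # arbitrary port for the sketch
    print(drain_zmq_socket(pull))  # prints [] when nothing has been sent yet
    print(drain_mp_queue(multiprocessing.Queue()))  # also []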
sglang/srt/managers/controller/model_runner.py
@@ -9,19 +9,23 @@ from typing import Optional, Type

 import torch
 import torch.nn as nn
+from flashinfer import (
+    BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
+)
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
-from vllm.model_executor.model_loader import get_model
+from vllm.distributed import (
+    get_tp_group,
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import (
-    Batch,
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -87,12 +91,6 @@ class ModelRunner:
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )

-        # Set some global args
-        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
-        global_server_args_dict[
-            "attention_reduce_in_fp32"
-        ] = server_args.attention_reduce_in_fp32
-
         # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
@@ -124,6 +122,15 @@ class ModelRunner:
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)

+        if (
+            self.server_args.efficient_weight_load
+            and "llama" in self.server_args.model_path.lower()
+            and self.server_args.quantization == "fp8"
+        ):
+            from sglang.srt.model_loader.model_loader import get_model
+        else:
+            from vllm.model_executor.model_loader import get_model
+
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,
@@ -169,7 +176,10 @@ class ModelRunner:
         )

         self.req_to_token_pool = ReqToTokenPool(
-            int(self.max_total_num_tokens / self.model_config.context_len * 256),
+            max(
+                int(self.max_total_num_tokens / self.model_config.context_len * 512),
+                2048,
+            ),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
@@ -200,13 +210,6 @@ class ModelRunner:
             self.flashinfer_decode_wrapper = None
             return

-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-            BatchPrefillWithRaggedKVCacheWrapper,
-        )
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         if not _grouped_size_compiled_for_decode_kernels(
             self.model_config.num_attention_heads // self.tp_size,
             self.model_config.get_num_kv_heads(self.tp_size),
@@ -237,12 +240,24 @@ class ModelRunner:
             self.cuda_graph_runner = None
             return

-        logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
-        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
+        )
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
         self.cuda_graph_runner = CudaGraphRunner(
-            self, max_batch_size_to_capture=max(batch_size_list)
+            self,
+            max_batch_size_to_capture=max(batch_size_list),
+            use_torch_compile=self.server_args.enable_torch_compile,
         )
-        self.cuda_graph_runner.capture(batch_size_list)
+        try:
+            self.cuda_graph_runner.capture(batch_size_list)
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed: {e}. Possible solutions:\n"
+                f"1. disable cuda graph by --disable-cuda-graph\n"
+                f"2. set --mem-fraction-static to a smaller value\n"
+                f"Open an issue on GitHub with reproducible scripts if you need help.\n"
+            )

     @torch.inference_mode()
     def forward_decode(self, batch: Batch):
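One small but easy-to-miss change in the model_runner.py hunks above: the ReqToTokenPool request-slot count is now derived from the token pool capacity with a larger multiplier (512 instead of 256) and a floor of 2048 slots. A short worked sketch of that formula; the function name is illustrative only.

# Minimal sketch (not the sglang API): the new request-slot sizing rule.
def req_pool_size(max_total_num_tokens: int, context_len: int) -> int:
    # 512 slots per "context_len worth" of pooled tokens, but never fewer than 2048.
    return max(int(max_total_num_tokens / context_len * 512), 2048)


# Example: 400k pooled tokens with a 4096-token context -> 50000 request slots;
# a tiny pool still gets the 2048-slot floor.
assert req_pool_size(400_000, 4096) == 50_000
assert req_pool_size(8_192, 4096) == 2048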
sglang/srt/managers/controller/schedule_heuristic.py
@@ -14,7 +14,7 @@ class ScheduleHeuristic:
         tree_cache,
     ):
         if tree_cache.disable and schedule_heuristic == "lpm":
-            # LMP is not meaningless when tree cache is disabled.
+            # LMP is meaningless when the tree cache is disabled.
             schedule_heuristic = "fcfs"

         self.schedule_heuristic = schedule_heuristic
@@ -28,11 +28,16 @@ class ScheduleHeuristic:
             # longest prefix match
             forward_queue.sort(key=lambda x: -len(x.prefix_indices))
             return forward_queue
+        elif self.schedule_heuristic == "fcfs":
+            # first come first serve
+            return forward_queue
+        elif self.schedule_heuristic == "lof":
+            # longest output first
+            forward_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
+            return forward_queue
         elif self.schedule_heuristic == "random":
             random.shuffle(forward_queue)
             return forward_queue
-        elif self.schedule_heuristic == "fcfs":
-            return forward_queue
         elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
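The schedule_heuristic.py hunk adds a "lof" (longest output first) policy alongside the existing "lpm" (longest prefix match), "fcfs", "random", and "dfs-weight" orderings. The sketch below shows the first four orderings on a toy request type; ToyReq and order are illustrative stand-ins, not sglang code.

# Minimal sketch (not the sglang API): the queue orderings used by the heuristics above.
import random
from dataclasses import dataclass


@dataclass
class ToyReq:
    prefix_len: int        # stands in for len(req.prefix_indices)
    max_new_tokens: int    # stands in for req.sampling_params.max_new_tokens


def order(queue: list, heuristic: str) -> list:
    if heuristic == "lpm":
        # longest prefix match: most cached prefix first
        return sorted(queue, key=lambda r: -r.prefix_len)
    elif heuristic == "fcfs":
        # first come first serve: keep arrival order
        return queue
    elif heuristic == "lof":
        # longest output first
        return sorted(queue, key=lambda r: -r.max_new_tokens)
    elif heuristic == "random":
        random.shuffle(queue)
        return queue
    raise ValueError(f"unknown heuristic: {heuristic}")


if __name__ == "__main__":
    reqs = [ToyReq(4, 128), ToyReq(32, 16), ToyReq(0, 512)]
    print(order(list(reqs), "lpm"))  # the request with prefix_len=32 comes first
    print(order(list(reqs), "lof"))  # the request with max_new_tokens=512 comes first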