sglang 0.1.20__py3-none-any.whl → 0.1.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/__init__.py +8 -8
  2. sglang/api.py +1 -1
  3. sglang/backend/runtime_endpoint.py +14 -4
  4. sglang/backend/vertexai.py +5 -4
  5. sglang/bench.py +627 -0
  6. sglang/bench_latency.py +22 -20
  7. sglang/bench_serving.py +758 -0
  8. sglang/check_env.py +171 -0
  9. sglang/global_config.py +3 -1
  10. sglang/lang/backend/__init__.py +0 -0
  11. sglang/lang/backend/anthropic.py +77 -0
  12. sglang/lang/backend/base_backend.py +80 -0
  13. sglang/lang/backend/litellm.py +90 -0
  14. sglang/lang/backend/openai.py +438 -0
  15. sglang/lang/backend/runtime_endpoint.py +283 -0
  16. sglang/lang/backend/vertexai.py +149 -0
  17. sglang/lang/chat_template.py +2 -2
  18. sglang/lang/ir.py +3 -3
  19. sglang/lang/tracer.py +1 -1
  20. sglang/launch_server.py +1 -1
  21. sglang/launch_server_llavavid.py +1 -4
  22. sglang/srt/conversation.py +1 -1
  23. sglang/srt/layers/context_flashattention_nopad.py +0 -29
  24. sglang/srt/layers/extend_attention.py +0 -39
  25. sglang/srt/layers/linear.py +869 -0
  26. sglang/srt/layers/quantization/__init__.py +49 -0
  27. sglang/srt/layers/quantization/fp8.py +662 -0
  28. sglang/srt/layers/radix_attention.py +31 -5
  29. sglang/srt/layers/token_attention.py +1 -51
  30. sglang/srt/managers/controller/cuda_graph_runner.py +44 -18
  31. sglang/srt/managers/controller/infer_batch.py +76 -72
  32. sglang/srt/managers/controller/manager_multi.py +109 -98
  33. sglang/srt/managers/controller/manager_single.py +105 -50
  34. sglang/srt/managers/controller/model_runner.py +42 -18
  35. sglang/srt/managers/controller/radix_cache.py +4 -3
  36. sglang/srt/managers/controller/schedule_heuristic.py +4 -0
  37. sglang/srt/managers/controller/tp_worker.py +143 -156
  38. sglang/srt/managers/detokenizer_manager.py +49 -5
  39. sglang/srt/managers/io_struct.py +36 -17
  40. sglang/srt/managers/tokenizer_manager.py +228 -125
  41. sglang/srt/memory_pool.py +46 -58
  42. sglang/srt/model_loader/model_loader.py +277 -0
  43. sglang/srt/model_loader/utils.py +260 -0
  44. sglang/srt/models/chatglm.py +1 -0
  45. sglang/srt/models/dbrx.py +1 -0
  46. sglang/srt/models/grok.py +1 -0
  47. sglang/srt/models/internlm2.py +317 -0
  48. sglang/srt/models/llama2.py +65 -16
  49. sglang/srt/models/llama_classification.py +1 -0
  50. sglang/srt/models/llava.py +1 -0
  51. sglang/srt/models/llavavid.py +1 -0
  52. sglang/srt/models/minicpm.py +2 -8
  53. sglang/srt/models/mixtral.py +1 -0
  54. sglang/srt/models/mixtral_quant.py +1 -0
  55. sglang/srt/models/qwen.py +1 -0
  56. sglang/srt/models/qwen2.py +6 -0
  57. sglang/srt/models/qwen2_moe.py +130 -108
  58. sglang/srt/models/stablelm.py +1 -0
  59. sglang/srt/openai_api/adapter.py +432 -0
  60. sglang/srt/openai_api/api_adapter.py +432 -0
  61. sglang/srt/openai_api/openai_api_adapter.py +431 -0
  62. sglang/srt/openai_api/openai_protocol.py +207 -0
  63. sglang/srt/openai_api/protocol.py +208 -0
  64. sglang/srt/openai_protocol.py +17 -0
  65. sglang/srt/sampling_params.py +2 -0
  66. sglang/srt/server.py +114 -90
  67. sglang/srt/server_args.py +27 -17
  68. sglang/srt/utils.py +17 -118
  69. sglang/test/test_conversation.py +1 -1
  70. sglang/test/test_openai_protocol.py +1 -1
  71. sglang/test/test_programs.py +1 -1
  72. sglang/test/test_utils.py +2 -2
  73. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/METADATA +157 -159
  74. sglang-0.1.22.dist-info/RECORD +103 -0
  75. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/WHEEL +1 -1
  76. sglang-0.1.20.dist-info/RECORD +0 -82
  77. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/LICENSE +0 -0
  78. {sglang-0.1.20.dist-info → sglang-0.1.22.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/manager_multi.py

@@ -3,19 +3,17 @@ A controller that manages multiple data parallel workers.
 Each data parallel worker can manage multiple tensor parallel workers.
 """

-import asyncio
+import dataclasses
 import logging
-from concurrent.futures import ThreadPoolExecutor
+import multiprocessing
+import os
 from enum import Enum, auto
-from typing import Dict

+import numpy as np
 import zmq
-import zmq.asyncio

-from sglang.global_config import global_config
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
+from sglang.srt.managers.controller.manager_single import (
+    start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -23,12 +21,15 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
 )
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger("srt.controller")


 class LoadBalanceMethod(Enum):
+    """Load balance method."""
+
     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()

@@ -41,151 +42,161 @@ class LoadBalanceMethod(Enum):
         raise ValueError(f"Invalid load balance method: {method}") from exc


-class Controller:
+@dataclasses.dataclass
+class WorkerHandle:
+    """Store the handle of a data parallel worker."""
+
+    proc: multiprocessing.Process
+    queue: multiprocessing.Queue
+
+
+class ControllerMulti:
+    """A controller that manages multiple data parallel workers."""
+
     def __init__(
         self,
-        load_balance_method: str,
         server_args: ServerArgs,
         port_args: PortArgs,
         model_overide_args,
     ):
-        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+        # Parse args
         self.server_args = server_args
         self.port_args = port_args
+        self.model_overide_args = model_overide_args
+        self.load_balance_method = LoadBalanceMethod.from_str(
+            server_args.load_balance_method
+        )

-        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
-            self.round_robin_counter = 0
+        # Init communication
+        context = zmq.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")

-        self.dispatch_lookup = {
+        # Dispatch method
+        self.round_robin_counter = 0
+        dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
         }
-        self.dispatching = self.dispatch_lookup[self.load_balance_method]
-
-        # Init communication
-        context = zmq.asyncio.Context()
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
-
-        # Init status
-        self.recv_reqs = []
+        self.dispatching = dispatch_lookup[self.load_balance_method]

         # Start data parallel workers
-        self.workers: Dict[int, DataParallelWorkerThread] = {}
-        tp_size = server_args.tp_size
-
-        def start_dp_worker(i):
-            try:
-                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
-                worker_thread = start_data_parallel_worker(
-                    server_args, port_args, model_overide_args, gpu_ids, i
-                )
-                self.workers[i] = worker_thread
-            except Exception:
-                logger.error(
-                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
-                )
-
+        self.workers = []
         for i in range(server_args.dp_size):
-            start_dp_worker(i)
+            self.start_dp_worker(i)

-        # Parallel launch is slower, probably due to the disk bandwidth limitations.
-        # with ThreadPoolExecutor(server_args.dp_size) as executor:
-        #     executor.map(start_dp_worker, range(server_args.dp_size))
+    def start_dp_worker(self, dp_worker_id: int):
+        tp_size = self.server_args.tp_size
+
+        pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
+            duplex=False
+        )

-    def have_any_live_worker(self):
-        return any(worker_thread.liveness for worker_thread in self.workers.values())
+        gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
+        queue = multiprocessing.Queue()
+        proc = multiprocessing.Process(
+            target=start_controller_process_single,
+            args=(
+                self.server_args,
+                self.port_args,
+                pipe_controller_writer,
+                self.model_overide_args,
+                True,
+                gpu_ids,
+                dp_worker_id,
+                queue,
+            ),
+        )
+        proc.start()

-    def put_req_to_worker(self, worker_id, req):
-        self.workers[worker_id].request_queue.put(req)
+        controller_init_state = pipe_controller_reader.recv()
+        if controller_init_state != "init ok":
+            raise RuntimeError(
+                f"Initialization failed. controller_init_state: {controller_init_state}"
+            )
+        self.workers.append(
+            WorkerHandle(
+                proc=proc,
+                queue=queue,
+            )
+        )

-    async def round_robin_scheduler(self, input_requests):
-        available_workers = list(self.workers.keys())
+    def round_robin_scheduler(self, input_requests):
         for r in input_requests:
-            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+            self.workers[self.round_robin_counter].queue.put(r)
             self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                available_workers
+                self.workers
             )
-        return

-    async def shortest_queue_scheduler(self, input_requests):
+    def shortest_queue_scheduler(self, input_requests):
         for r in input_requests:
-            worker = min(
-                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
-            )
-            self.put_req_to_worker(worker, r)
-        return
-
-    async def remove_dead_workers(self):
-        for i in list(self.workers.keys()):
-            worker_thread = self.workers[i]
-            if not worker_thread.liveness:
-                worker_thread.join()
-                # move unsuccessful requests back to the queue
-                while not worker_thread.request_queue.empty():
-                    self.recv_reqs.append(worker_thread.request_queue.get())
-                del self.workers[i]
-                logger.info(f"Stale worker {i} removed")
-
-    async def loop_for_forward(self):
-        while True:
-            await self.remove_dead_workers()
+            queue_sizes = [worker.queue.qsize() for worker in self.workers]
+            wid = np.argmin(queue_sizes)
+            self.workers[wid].queue.put(r)

-            if self.have_any_live_worker():
-                next_step_input = list(self.recv_reqs)
-                self.recv_reqs = []
-                if next_step_input:
-                    await self.dispatching(next_step_input)
-            # else:
-            #     logger.error("There is no live worker.")
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+            self.dispatching(recv_reqs)

-            await asyncio.sleep(global_config.wait_for_new_request_delay)
+    def recv_requests(self):
+        recv_reqs = []

-    async def loop_for_recv_requests(self):
         while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+
             if isinstance(recv_req, FlushCacheReq):
                 # TODO(lsyin): apply more specific flushCacheReq
-                for worker_thread in self.workers.values():
-                    worker_thread.request_queue.put(recv_req)
-            elif isinstance(recv_req, TokenizedGenerateReqInput):
-                self.recv_reqs.append(recv_req)
+                for worker in self.workers:
+                    worker.queue.put(recv_req)
             elif isinstance(recv_req, AbortReq):
                 in_queue = False
-                for i, req in enumerate(self.recv_reqs):
+                for i, req in enumerate(recv_reqs):
                     if req.rid == recv_req.rid:
-                        self.recv_reqs[i] = recv_req
+                        recv_reqs[i] = recv_req
                         in_queue = True
                         break
                 if not in_queue:
                     # Send abort req to all TP groups
-                    for worker in list(self.workers.keys()):
-                        self.put_req_to_worker(worker, recv_req)
+                    for worker in self.workers:
+                        worker.queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                recv_reqs.append(recv_req)
             else:
                 logger.error(f"Invalid object: {recv_req}")

+        return recv_reqs
+

 def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args=None,
+    model_overide_args: dict,
 ):
+    """Start a controller process."""
+
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

     try:
-        controller = Controller(
-            server_args.load_balance_method, server_args, port_args, model_overide_args
-        )
+        controller = ControllerMulti(server_args, port_args, model_overide_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise

     pipe_writer.send("init ok")
-    loop = asyncio.get_event_loop()
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
-    loop.run_until_complete(controller.loop_for_forward())
+
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
+    finally:
+        for w in controller.workers:
+            os.kill(w.proc.pid, 9)
+        kill_parent_process()
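
For orientation, the two dispatch policies in the rewritten ControllerMulti reduce to a simple queue-selection pattern. Below is a minimal standalone sketch of that pattern; the function names and the bare multiprocessing queues are illustrative stand-ins, not sglang APIs.

import multiprocessing
import time

import numpy as np


def round_robin_dispatch(queues, requests, counter):
    # Cycle through workers in a fixed order, one request at a time.
    for r in requests:
        queues[counter].put(r)
        counter = (counter + 1) % len(queues)
    return counter


def shortest_queue_dispatch(queues, requests):
    # Send each request to the worker with the fewest pending requests.
    # Queue.qsize() is approximate, which is acceptable for load balancing.
    for r in requests:
        sizes = [q.qsize() for q in queues]
        queues[int(np.argmin(sizes))].put(r)


if __name__ == "__main__":
    queues = [multiprocessing.Queue() for _ in range(4)]
    counter = round_robin_dispatch(queues, range(10), 0)
    shortest_queue_dispatch(queues, range(10))
    time.sleep(0.1)  # let the queue feeder threads flush before reading sizes
    print([q.qsize() for q in queues])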
sglang/srt/managers/controller/manager_single.py

@@ -1,102 +1,157 @@
 """A controller that manages a group of tensor parallel workers."""

-import asyncio
 import logging
-from concurrent.futures import ThreadPoolExecutor
+import multiprocessing
+import os
+from typing import List

-import uvloop
 import zmq
-import zmq.asyncio

-from sglang.global_config import global_config
-from sglang.srt.managers.controller.tp_worker import ModelTpClient
+from sglang.srt.managers.controller.tp_worker import (
+    ModelTpServer,
+    broadcast_recv_input,
+    launch_tp_servers,
+)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import kill_parent_process
 from sglang.utils import get_exception_traceback

-asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
-
 logger = logging.getLogger("srt.controller")


 class ControllerSingle:
-    def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
+    """A controller that manages a group of tensor parallel workers."""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args: dict,
+        gpu_ids: List[int],
+        is_data_parallel_worker: bool,
+        dp_worker_id: int,
+        mp_queue: multiprocessing.Queue,
+    ):
+        # Parse args
+        self.tp_size = server_args.tp_size
+        self.is_dp_worker = is_data_parallel_worker
+        self.dp_worker_id = dp_worker_id
+        self.mp_queue = mp_queue
+
         # Init communication
-        context = zmq.asyncio.Context(2)
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+        context = zmq.Context(2)
+
+        if not self.is_dp_worker:
+            self.recv_from_tokenizer = context.socket(zmq.PULL)
+            self.recv_from_tokenizer.bind(
+                f"tcp://127.0.0.1:{port_args.controller_port}"
+            )

         self.send_to_detokenizer = context.socket(zmq.PUSH)
         self.send_to_detokenizer.connect(
             f"tcp://127.0.0.1:{port_args.detokenizer_port}"
         )

-        # Init status
-        self.model_client = model_client
-        self.recv_reqs = []
-
-        # Init some configs
-        self.request_dependency_delay = global_config.request_dependency_delay
+        # Launch other tp ranks
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        self.tp_procs = []
+        if tp_size_local > 1:
+            tp_rank_range = range(1, tp_size_local)
+            self.tp_procs = launch_tp_servers(
+                gpu_ids,
+                tp_rank_range,
+                server_args,
+                port_args.nccl_ports[dp_worker_id],
+                model_overide_args,
+            )
+
+        # Launch tp rank 0
+        self.tp_server = ModelTpServer(
+            gpu_ids[0],
+            0,
+            server_args,
+            port_args.nccl_ports[dp_worker_id],
+            model_overide_args,
+        )
+        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

-    async def loop_for_forward(self):
+    def loop_for_forward(self):
         while True:
-            next_step_input = list(self.recv_reqs)
-            self.recv_reqs = []
-            out_pyobjs = await self.model_client.step(next_step_input)
+            if not self.is_dp_worker:
+                recv_reqs = self.recv_requests_from_zmq()
+            else:
+                recv_reqs = self.recv_requests_from_mp_queue()
+
+            if self.tp_size > 1:
+                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
+
+            out_pyobjs = self.tp_server.exposed_step(recv_reqs)

             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)

-            # async sleep for receiving the subsequent request and avoiding cache miss
-            slept = False
-            if len(out_pyobjs) != 0:
-                has_finished = any(
-                    [obj.finished_reason is not None for obj in out_pyobjs]
-                )
-                if has_finished:
-                    if self.request_dependency_delay > 0:
-                        slept = True
-                        await asyncio.sleep(self.request_dependency_delay)
-
-            if not slept:
-                await asyncio.sleep(global_config.wait_for_new_request_delay)
-
-    async def loop_for_recv_requests(self):
+    def recv_requests_from_zmq(self):
+        recv_reqs = []
         while True:
-            recv_req = await self.recv_from_tokenizer.recv_pyobj()
-            self.recv_reqs.append(recv_req)
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+            recv_reqs.append(recv_req)
+
+        return recv_reqs
+
+    def recv_requests_from_mp_queue(self):
+        recv_reqs = []
+        while not self.mp_queue.empty():
+            recv_reqs.append(self.mp_queue.get())
+        return recv_reqs


 def start_controller_process(
-    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer: multiprocessing.connection.Connection,
+    model_overide_args: dict,
+    is_data_parallel_worker: bool = False,
+    gpu_ids: List[int] = None,
+    dp_worker_id: int = None,
+    queue: multiprocessing.connection.Connection = None,
 ):
+    """Start a controller process."""
+
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format="%(message)s",
     )

-    try:
+    if not is_data_parallel_worker:
         tp_size_local = server_args.tp_size // server_args.nnodes
-        model_client = ModelTpClient(
-            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        dp_worker_id = 0
+        queue = None
+
+    try:
+        controller = ControllerSingle(
             server_args,
-            port_args.model_port_args[0],
+            port_args,
             model_overide_args,
+            gpu_ids,
+            is_data_parallel_worker,
+            dp_worker_id,
+            queue,
         )
-        controller = ControllerSingle(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise

     pipe_writer.send("init ok")

-    loop = asyncio.new_event_loop()
-    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
-    asyncio.set_event_loop(loop)
-    loop.create_task(controller.loop_for_recv_requests())
     try:
-        loop.run_until_complete(controller.loop_for_forward())
+        controller.loop_for_forward()
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
+        for t in controller.tp_procs:
+            os.kill(t.pid, 9)
         kill_parent_process()
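
The synchronous rewrite above replaces the old asyncio receive task with a non-blocking drain of the ZMQ socket. A minimal sketch of that pattern, assuming an arbitrary local port:

import zmq


def drain(socket: zmq.Socket) -> list:
    # Read everything currently queued on the socket, then return,
    # so the caller's main loop never blocks waiting for input.
    msgs = []
    while True:
        try:
            # NOBLOCK makes recv raise zmq.Again (a ZMQError subclass)
            # as soon as the queue is empty.
            msgs.append(socket.recv_pyobj(zmq.NOBLOCK))
        except zmq.ZMQError:
            break
    return msgs


if __name__ == "__main__":
    ctx = zmq.Context()
    pull = ctx.socket(zmq.PULL)
    pull.bind("tcp://127.0.0.1:5555")  # port chosen arbitrarily for this demo
    print(drain(pull))  # [] when nothing has been sent yet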
sglang/srt/managers/controller/model_runner.py

@@ -9,14 +9,24 @@ from typing import Optional, Type

 import torch
 import torch.nn as nn
+from flashinfer import (
+    BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
+)
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import init_distributed_environment, initialize_model_parallel
+from vllm.distributed import (
+    get_tp_group,
+    init_distributed_environment,
+    initialize_model_parallel,
+)
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -70,6 +80,7 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
@@ -81,10 +92,6 @@ class ModelRunner:
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )

-        # Set some global args
-        global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
-        global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
-
         # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
@@ -116,6 +123,15 @@ class ModelRunner:
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)

+        if (
+            self.server_args.efficient_weight_load
+            and "llama" in self.server_args.model_path.lower()
+            and self.server_args.quantization == "fp8"
+        ):
+            from sglang.srt.model_loader.model_loader import get_model
+        else:
+            from vllm.model_executor.model_loader import get_model
+
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,
@@ -161,7 +177,10 @@ class ModelRunner:
         )

         self.req_to_token_pool = ReqToTokenPool(
-            int(self.max_total_num_tokens / self.model_config.context_len * 256),
+            max(
+                int(self.max_total_num_tokens / self.model_config.context_len * 512),
+                2048,
+            ),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
@@ -192,13 +211,6 @@ class ModelRunner:
             self.flashinfer_decode_wrapper = None
             return

-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-            BatchPrefillWithRaggedKVCacheWrapper,
-        )
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         if not _grouped_size_compiled_for_decode_kernels(
             self.model_config.num_attention_heads // self.tp_size,
             self.model_config.get_num_kv_heads(self.tp_size),
@@ -217,7 +229,9 @@ class ModelRunner:
             self.flashinfer_workspace_buffers[1], "NHD"
         )
         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+            self.flashinfer_workspace_buffers[0],
+            "NHD",
+            use_tensor_cores=use_tensor_cores,
         )

     def init_cuda_graphs(self):
@@ -228,9 +242,19 @@ class ModelRunner:
             return

         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
-        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
-        self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
-        self.cuda_graph_runner.capture(batch_size_list)
+        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+        self.cuda_graph_runner = CudaGraphRunner(
+            self, max_batch_size_to_capture=max(batch_size_list)
+        )
+        try:
+            self.cuda_graph_runner.capture(batch_size_list)
+        except RuntimeError as e:
+            raise Exception(
+                f"Capture cuda graph failed {e}. Possible solutions:\n"
+                f"1. disable cuda graph by --disable-cuda-graph\n"
+                f"2. set --mem-fraction-static to a smaller value\n"
+                f"Open an issue on GitHub with reproducible scripts if you need help.\n"
+            )

     @torch.inference_mode()
     def forward_decode(self, batch: Batch):
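
Two of the hunks above change sizing: the ReqToTokenPool slot count doubles its scaling factor (256 to 512) and gains a floor of 2048, and CUDA graph capture now extends to batch size 128 (range(1, 17)). A quick worked example of the pool formula, with assumed capacity numbers:

# Numbers below are assumptions chosen only to illustrate the formula.
max_total_num_tokens = 200_000  # tokens the KV cache can hold (assumed)
context_len = 4096              # model context length (assumed)

old_size = int(max_total_num_tokens / context_len * 256)             # 12500
new_size = max(int(max_total_num_tokens / context_len * 512), 2048)  # 25000

print(old_size, new_size)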
sglang/srt/managers/controller/radix_cache.py

@@ -82,12 +82,12 @@ class RadixCache:

         if self.disable:
             if del_in_memory_pool:
-                self.token_to_kv_pool.dec_refs(indices)
+                self.token_to_kv_pool.free(indices)
             else:
                 return torch.tensor([], dtype=torch.int64), self.root_node

         # Radix Cache takes one ref in memory pool
-        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])

         if del_in_memory_pool:
             self.req_to_token_pool.free(req_pool_idx)
@@ -125,7 +125,8 @@ class RadixCache:
             if x.lock_ref > 0:
                 continue

-            num_evicted += evict_callback(x.value)
+            evict_callback(x.value)
+            num_evicted += len(x.value)
             self._delete_leaf(x)

             if len(x.parent.children) == 0:
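
The second RadixCache hunk fixes eviction accounting: the token count now comes from the evicted node's own token list rather than from the callback's return value. A minimal sketch of that accounting, with plain dicts standing in for tree nodes:

def evict(leaves, num_tokens, evict_callback):
    # Evict unlocked leaves until at least num_tokens tokens are freed.
    num_evicted = 0
    for node in leaves:  # assume leaves are already sorted and unlocked
        if num_evicted >= num_tokens:
            break
        evict_callback(node["value"])  # frees KV slots; return value unused
        num_evicted += len(node["value"])
    return num_evicted


leaves = [{"value": [1, 2, 3]}, {"value": [4, 5]}]
print(evict(leaves, 4, lambda v: None))  # 5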
sglang/srt/managers/controller/schedule_heuristic.py

@@ -13,6 +13,10 @@ class ScheduleHeuristic:
         max_total_num_tokens,
         tree_cache,
     ):
+        if tree_cache.disable and schedule_heuristic == "lpm":
+            # LPM is meaningless when the tree cache is disabled.
+            schedule_heuristic = "fcfs"
+
         self.schedule_heuristic = schedule_heuristic
         self.max_running_seqs = max_running_seqs
         self.max_prefill_num_tokens = max_prefill_num_tokens
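
The guard above falls back to first-come-first-served because longest-prefix-match ordering has nothing to match against once the radix tree cache is disabled. A tiny sketch of the decision, with an illustrative helper name:

def resolve_heuristic(schedule_heuristic: str, cache_disabled: bool) -> str:
    # With no tree cache there are no cached prefixes, so "lpm" carries
    # no signal and "fcfs" is used instead.
    if cache_disabled and schedule_heuristic == "lpm":
        return "fcfs"
    return schedule_heuristic


assert resolve_heuristic("lpm", cache_disabled=True) == "fcfs"
assert resolve_heuristic("lpm", cache_disabled=False) == "lpm"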