sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/manager_multi.py
@@ -0,0 +1,195 @@
+"""
+A controller that manages multiple data parallel workers.
+Each data parallel worker can manage multiple tensor parallel workers.
+"""
+
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from enum import Enum, auto
+from typing import Dict
+
+import zmq
+import zmq.asyncio
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.dp_worker import (
+    DataParallelWorkerThread,
+    start_data_parallel_worker,
+)
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    FlushCacheReq,
+    TokenizedGenerateReqInput,
+)
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger("srt.controller")
+
+
+class LoadBalanceMethod(Enum):
+    ROUND_ROBIN = auto()
+    SHORTEST_QUEUE = auto()
+
+    @classmethod
+    def from_str(cls, method: str):
+        method = method.upper()
+        try:
+            return cls[method]
+        except KeyError as exc:
+            raise ValueError(f"Invalid load balance method: {method}") from exc
+
+
+class Controller:
+    """A controller that manages multiple data parallel workers."""
+
+    def __init__(
+        self,
+        load_balance_method: str,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args,
+    ):
+        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+        self.server_args = server_args
+        self.port_args = port_args
+
+        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
+            self.round_robin_counter = 0
+
+        self.dispatch_lookup = {
+            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
+            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+        }
+        self.dispatching = self.dispatch_lookup[self.load_balance_method]
+
+        # Init communication
+        context = zmq.asyncio.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        # Init status
+        self.recv_reqs = []
+
+        # Start data parallel workers
+        self.workers: Dict[int, DataParallelWorkerThread] = {}
+        tp_size = server_args.tp_size
+
+        def start_dp_worker(i):
+            try:
+                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
+                worker_thread = start_data_parallel_worker(
+                    server_args, port_args, model_overide_args, gpu_ids, i
+                )
+                self.workers[i] = worker_thread
+            except Exception:
+                logger.error(
+                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
+                )
+
+        for i in range(server_args.dp_size):
+            start_dp_worker(i)
+
+        # Parallel launch is slower, probably due to the disk bandwidth limitations.
+        # with ThreadPoolExecutor(server_args.dp_size) as executor:
+        #     executor.map(start_dp_worker, range(server_args.dp_size))
+
+    def have_any_live_worker(self):
+        return any(worker_thread.liveness for worker_thread in self.workers.values())
+
+    def put_req_to_worker(self, worker_id, req):
+        self.workers[worker_id].request_queue.put(req)
+
+    async def round_robin_scheduler(self, input_requests):
+        available_workers = list(self.workers.keys())
+        for r in input_requests:
+            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                available_workers
+            )
+        return
+
+    async def shortest_queue_scheduler(self, input_requests):
+        for r in input_requests:
+            worker = min(
+                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
+            )
+            self.put_req_to_worker(worker, r)
+        return
+
+    async def remove_dead_workers(self):
+        for i in list(self.workers.keys()):
+            worker_thread = self.workers[i]
+            if not worker_thread.liveness:
+                worker_thread.join()
+                # move unsuccessful requests back to the queue
+                while not worker_thread.request_queue.empty():
+                    self.recv_reqs.append(worker_thread.request_queue.get())
+                del self.workers[i]
+                logger.info(f"Stale worker {i} removed")

+    async def loop_for_forward(self):
+        while True:
+            await self.remove_dead_workers()
+
+            if self.have_any_live_worker():
+                next_step_input = list(self.recv_reqs)
+                self.recv_reqs = []
+                if next_step_input:
+                    await self.dispatching(next_step_input)
+            # else:
+            #     logger.error("There is no live worker.")
+
+            await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def loop_for_recv_requests(self):
+        while True:
+            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            if isinstance(recv_req, FlushCacheReq):
+                # TODO(lsyin): apply more specific flushCacheReq
+                for worker_thread in self.workers.values():
+                    worker_thread.request_queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                self.recv_reqs.append(recv_req)
+            elif isinstance(recv_req, AbortReq):
+                in_queue = False
+                for i, req in enumerate(self.recv_reqs):
+                    if req.rid == recv_req.rid:
+                        self.recv_reqs[i] = recv_req
+                        in_queue = True
+                        break
+                if not in_queue:
+                    # Send abort req to all TP groups
+                    for worker in list(self.workers.keys()):
+                        self.put_req_to_worker(worker, recv_req)
+            else:
+                logger.error(f"Invalid object: {recv_req}")
+
+
+def start_controller_process(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer,
+    model_overide_args=None,
+):
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        controller = Controller(
+            server_args.load_balance_method, server_args, port_args, model_overide_args
+        )
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+    pipe_writer.send("init ok")
+
+    loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
+
+    asyncio.set_event_loop(loop)
+    loop.create_task(controller.loop_for_recv_requests())
+    loop.run_until_complete(controller.loop_for_forward())
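The two schedulers above differ only in how a target worker is picked for each request: round-robin cycles through the worker ids, while shortest-queue sends each request to the worker whose queue currently holds the fewest pending requests. A minimal standalone sketch of the same selection logic, using plain queue.Queue objects and a made-up pool of four workers instead of DataParallelWorkerThread, might look like this:

    # Illustrative only: four hypothetical worker queues standing in for
    # DataParallelWorkerThread.request_queue.
    import queue

    workers = {i: queue.Queue() for i in range(4)}
    round_robin_counter = 0


    def round_robin_dispatch(requests):
        # Cycle through worker ids, one request per worker in turn.
        global round_robin_counter
        ids = list(workers.keys())
        for r in requests:
            workers[ids[round_robin_counter]].put(r)
            round_robin_counter = (round_robin_counter + 1) % len(ids)


    def shortest_queue_dispatch(requests):
        # Send each request to the worker with the fewest queued requests.
        for r in requests:
            target = min(workers, key=lambda w: workers[w].qsize())
            workers[target].put(r)


    round_robin_dispatch([f"req-{i}" for i in range(6)])
    shortest_queue_dispatch(["req-6", "req-7"])
    print({i: q.qsize() for i, q in workers.items()})  # {0: 2, 1: 2, 2: 2, 3: 2}

This mirrors round_robin_scheduler and shortest_queue_scheduler minus the async plumbing, liveness checks, and ZMQ wiring that the Controller adds around them.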
sglang/srt/managers/controller/manager_single.py
@@ -0,0 +1,177 @@
+"""A controller that manages a group of tensor parallel workers."""
+
+import multiprocessing
+import logging
+import os
+import pickle
+
+import torch
+import torch.distributed as dist
+import zmq
+import zmq.asyncio
+
+from sglang.srt.managers.controller.tp_worker import ModelTpServer
+from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback
+
+
+logger = logging.getLogger("srt.controller")
+
+
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    model_port_args: ModelPortArgs,
+    model_overide_args: dict,
+):
+    """Run a tp server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            model_port_args,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                      model_port_args, model_overide_args):
+    """Launch multiple tp servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(target=run_tp_server, args=(
+            gpu_ids[i], i, server_args, model_port_args, model_overide_args
+        ))
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
+
+
+class ControllerSingle:
+    """A controller that manages a group of tensor parallel workers."""
+
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
+        # Parse args
+        self.server_args = server_args
+
+        # Init communication
+        context = zmq.Context(2)
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        self.send_to_detokenizer = context.socket(zmq.PUSH)
+        self.send_to_detokenizer.connect(
+            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
+        )
+
+        # Init model server
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+
+        # Launch other tp ranks
+        if tp_size_local > 1:
+            tp_rank_range = range(1, tp_size_local)
+            self.tp_procs = launch_tp_servers(
+                gpu_ids, tp_rank_range, server_args,
+                port_args.model_port_args[0], model_overide_args)
+
+        # Launch tp rank 0
+        self.tp_server = ModelTpServer(
+            gpu_ids[0],
+            0,
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
+
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+
+            if self.server_args.tp_size > 1:
+                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
+
+            out_pyobjs = self.tp_server.exposed_step(recv_reqs)
+
+            for obj in out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+
+    def recv_requests(self):
+        recv_reqs = []
+        while True:
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                recv_reqs.append(recv_req)
+            except zmq.ZMQError:
+                break
+        return recv_reqs
+
+
+def start_controller_process(
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
+):
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        controller = ControllerSingle(server_args, port_args, model_overide_args)
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+
+    pipe_writer.send("init ok")
+
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        for t in controller.tp_procs:
+            os.kill(t.pid, 9)
+        kill_parent_process()
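broadcast_recv_input above exists because torch.distributed broadcasts tensors, not arbitrary Python objects: rank 0 pickles the request list, broadcasts the payload length, then the payload bytes, while the other ranks receive the length first so they can allocate a correctly sized buffer before receiving and unpickling. A self-contained sketch of that pattern, using two local processes on the gloo backend with a made-up rendezvous port and payload (illustrative only, not sglang code):

    # Illustrative only: pickle-then-broadcast between two local processes.
    import os
    import pickle

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def worker(rank, world_size):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29511"  # made-up port for the example
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        if rank == 0:
            data = ["request-1", {"rid": "abc"}]      # objects only rank 0 holds
            payload = pickle.dumps(data)
            size = torch.tensor([len(payload)], dtype=torch.long)
            dist.broadcast(size, src=0)               # 1) announce payload length
            dist.broadcast(torch.ByteTensor(list(payload)), src=0)  # 2) send bytes
        else:
            size = torch.tensor([0], dtype=torch.long)
            dist.broadcast(size, src=0)               # 1) learn payload length
            buf = torch.empty(size.item(), dtype=torch.uint8)
            dist.broadcast(buf, src=0)                # 2) receive bytes
            data = pickle.loads(bytes(buf.tolist()))
            print(f"rank {rank} got {data}")

        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(worker, args=(2,), nprocs=2)

Broadcasting the size first is also what lets the non-zero ranks call the function with data=None, as run_tp_server does.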