sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/manager_multi.py
@@ -0,0 +1,191 @@
+"""
+A controller that manages multiple data parallel workers.
+Each data parallel worker can manage multiple tensor parallel workers.
+"""
+
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from enum import Enum, auto
+from typing import Dict
+
+import zmq
+import zmq.asyncio
+
+from sglang.global_config import global_config
+from sglang.srt.managers.io_struct import (
+    AbortReq,
+    FlushCacheReq,
+    TokenizedGenerateReqInput,
+)
+from sglang.srt.managers.controller.dp_worker import (
+    DataParallelWorkerThread,
+    start_data_parallel_worker,
+)
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger("srt.controller")
+
+
+class LoadBalanceMethod(Enum):
+    ROUND_ROBIN = auto()
+    SHORTEST_QUEUE = auto()
+
+    @classmethod
+    def from_str(cls, method: str):
+        method = method.upper()
+        try:
+            return cls[method]
+        except KeyError as exc:
+            raise ValueError(f"Invalid load balance method: {method}") from exc
+
+
+class Controller:
+    def __init__(
+        self,
+        load_balance_method: str,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        model_overide_args,
+    ):
+        self.load_balance_method = LoadBalanceMethod.from_str(load_balance_method)
+        self.server_args = server_args
+        self.port_args = port_args
+
+        if self.load_balance_method == LoadBalanceMethod.ROUND_ROBIN:
+            self.round_robin_counter = 0
+
+        self.dispatch_lookup = {
+            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
+            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+        }
+        self.dispatching = self.dispatch_lookup[self.load_balance_method]
+
+        # Init communication
+        context = zmq.asyncio.Context()
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        # Init status
+        self.recv_reqs = []
+
+        # Start data parallel workers
+        self.workers: Dict[int, DataParallelWorkerThread] = {}
+        tp_size = server_args.tp_size
+
+        def start_dp_worker(i):
+            try:
+                gpu_ids = list(range(i * tp_size, (i + 1) * tp_size))
+                worker_thread = start_data_parallel_worker(
+                    server_args, port_args, model_overide_args, gpu_ids, i
+                )
+                self.workers[i] = worker_thread
+            except Exception:
+                logger.error(
+                    f"Failed to start local worker {i}\n{get_exception_traceback()}"
+                )
+
+        for i in range(server_args.dp_size):
+            start_dp_worker(i)
+
+        # Parallel launch is slower, probably due to the disk bandwidth limitations.
+        # with ThreadPoolExecutor(server_args.dp_size) as executor:
+        #     executor.map(start_dp_worker, range(server_args.dp_size))
+
+    def have_any_live_worker(self):
+        return any(worker_thread.liveness for worker_thread in self.workers.values())
+
+    def put_req_to_worker(self, worker_id, req):
+        self.workers[worker_id].request_queue.put(req)
+
+    async def round_robin_scheduler(self, input_requests):
+        available_workers = list(self.workers.keys())
+        for r in input_requests:
+            self.put_req_to_worker(available_workers[self.round_robin_counter], r)
+            self.round_robin_counter = (self.round_robin_counter + 1) % len(
+                available_workers
+            )
+        return
+
+    async def shortest_queue_scheduler(self, input_requests):
+        for r in input_requests:
+            worker = min(
+                self.workers, key=lambda w: self.workers[w].request_queue.qsize()
+            )
+            self.put_req_to_worker(worker, r)
+        return
+
+    async def remove_dead_workers(self):
+        for i in list(self.workers.keys()):
+            worker_thread = self.workers[i]
+            if not worker_thread.liveness:
+                worker_thread.join()
+                # move unsuccessful requests back to the queue
+                while not worker_thread.request_queue.empty():
+                    self.recv_reqs.append(worker_thread.request_queue.get())
+                del self.workers[i]
+                logger.info(f"Stale worker {i} removed")
+
+    async def loop_for_forward(self):
+        while True:
+            await self.remove_dead_workers()
+
+            if self.have_any_live_worker():
+                next_step_input = list(self.recv_reqs)
+                self.recv_reqs = []
+                if next_step_input:
+                    await self.dispatching(next_step_input)
+            # else:
+            #     logger.error("There is no live worker.")
+
+            await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def loop_for_recv_requests(self):
+        while True:
+            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            if isinstance(recv_req, FlushCacheReq):
+                # TODO(lsyin): apply more specific flushCacheReq
+                for worker_thread in self.workers.values():
+                    worker_thread.request_queue.put(recv_req)
+            elif isinstance(recv_req, TokenizedGenerateReqInput):
+                self.recv_reqs.append(recv_req)
+            elif isinstance(recv_req, AbortReq):
+                in_queue = False
+                for i, req in enumerate(self.recv_reqs):
+                    if req.rid == recv_req.rid:
+                        self.recv_reqs[i] = recv_req
+                        in_queue = True
+                        break
+                if not in_queue:
+                    # Send abort req to all TP groups
+                    for worker in list(self.workers.keys()):
+                        self.put_req_to_worker(worker, recv_req)
+            else:
+                logger.error(f"Invalid object: {recv_req}")
+
+
+def start_controller_process(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    pipe_writer,
+    model_overide_args=None,
+):
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        controller = Controller(
+            server_args.load_balance_method, server_args, port_args, model_overide_args
+        )
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+
+    pipe_writer.send("init ok")
+    loop = asyncio.get_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.create_task(controller.loop_for_recv_requests())
+    loop.run_until_complete(controller.loop_for_forward())
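For context, here is a minimal standalone sketch of the two dispatch policies that manager_multi.py adds above. The Worker stand-in and the sample request strings are hypothetical placeholders, not part of sglang; they only illustrate how round-robin and shortest-queue dispatch differ.

import queue

class Worker:
    """Hypothetical stand-in for DataParallelWorkerThread (only the queue matters here)."""
    def __init__(self):
        self.request_queue = queue.Queue()

workers = {0: Worker(), 1: Worker(), 2: Worker()}

# Round-robin: rotate a counter over the worker ids, one request at a time.
round_robin_counter = 0
for req in ["r0", "r1", "r2", "r3"]:
    ids = list(workers.keys())
    workers[ids[round_robin_counter]].request_queue.put(req)
    round_robin_counter = (round_robin_counter + 1) % len(ids)

# Shortest-queue: send each request to the worker with the smallest backlog.
for req in ["r4", "r5"]:
    target = min(workers, key=lambda w: workers[w].request_queue.qsize())
    workers[target].request_queue.put(req)

print({i: w.request_queue.qsize() for i, w in workers.items()})  # -> {0: 2, 1: 2, 2: 2}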
sglang/srt/managers/controller/manager_single.py
@@ -0,0 +1,97 @@
+"""A controller that manages a group of tensor parallel workers."""
+import asyncio
+import logging
+import time
+
+import uvloop
+import zmq
+import zmq.asyncio
+
+from sglang.global_config import global_config
+from sglang.srt.managers.controller.tp_worker import ModelTpClient
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback
+
+asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+logger = logging.getLogger("srt.controller")
+
+
+class ControllerSingle:
+    def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
+        # Init communication
+        context = zmq.asyncio.Context(2)
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        self.send_to_detokenizer = context.socket(zmq.PUSH)
+        self.send_to_detokenizer.connect(
+            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
+        )
+
+        # Init status
+        self.model_client = model_client
+        self.recv_reqs = []
+
+        # Init some configs
+        self.request_dependency_delay = global_config.request_dependency_delay
+
+    async def loop_for_forward(self):
+        while True:
+            next_step_input = list(self.recv_reqs)
+            self.recv_reqs = []
+            out_pyobjs = await self.model_client.step(next_step_input)
+
+            for obj in out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+
+            # async sleep for receiving the subsequent request and avoiding cache miss
+            slept = False
+            if len(out_pyobjs) != 0:
+                has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+                if has_finished:
+                    if self.request_dependency_delay > 0:
+                        slept = True
+                        await asyncio.sleep(self.request_dependency_delay)
+
+            if not slept:
+                await asyncio.sleep(global_config.wait_for_new_request_delay)
+
+    async def loop_for_recv_requests(self):
+        while True:
+            recv_req = await self.recv_from_tokenizer.recv_pyobj()
+            self.recv_reqs.append(recv_req)
+
+
+def start_controller_process(
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
+):
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        model_client = ModelTpClient(
+            list(range(server_args.tp_size)),
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        controller = ControllerSingle(model_client, port_args)
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+
+    pipe_writer.send("init ok")
+
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.create_task(controller.loop_for_recv_requests())
+    try:
+        loop.run_until_complete(controller.loop_for_forward())
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        kill_parent_process()
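For context, a minimal standalone asyncio sketch of the two-loop structure used by ControllerSingle above: one task drains incoming requests into a shared list while the other batches whatever has accumulated on each iteration. The sleep delays, iteration counts, and print call are illustrative stand-ins for recv_from_tokenizer.recv_pyobj(), model_client.step(), and the detokenizer socket; they are assumptions, not sglang APIs.

import asyncio

recv_reqs = []  # shared buffer, mirroring ControllerSingle.recv_reqs

async def loop_for_recv_requests():
    # Stand-in for recv_from_tokenizer.recv_pyobj(): produce a few fake requests.
    for i in range(5):
        await asyncio.sleep(0.01)
        recv_reqs.append(f"req-{i}")

async def loop_for_forward():
    for _ in range(10):  # the real loop runs forever
        next_step_input = list(recv_reqs)
        recv_reqs.clear()
        if next_step_input:
            print("step on", next_step_input)  # stand-in for model_client.step(...)
        await asyncio.sleep(0.01)  # stand-in for wait_for_new_request_delay

async def main():
    asyncio.create_task(loop_for_recv_requests())
    await loop_for_forward()

asyncio.run(main())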