sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/cuda_graph_runner.py (new file)
@@ -0,0 +1,196 @@
"""Run the model with cuda graph."""

import bisect

import torch
from vllm.distributed.parallel_state import graph_capture

from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.infer_batch import (
    Batch,
    ForwardMode,
    InputMetadata,
    init_flashinfer_args,
)


class CudaGraphRunner:
    def __init__(self, model_runner, max_batch_size_to_capture):
        self.model_runner = model_runner
        self.graphs = {}
        self.input_buffers = {}
        self.output_buffers = {}
        self.flashinfer_handlers = {}
        self.graph_memory_pool = None

        # Common inputs
        self.max_bs = max_batch_size_to_capture
        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
        self.req_pool_indices = torch.zeros(
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )
        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
        self.position_ids_offsets = torch.zeros(
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )
        self.out_cache_loc = torch.zeros(
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )

        # FlashInfer inputs
        self.flashinfer_workspace_buffer = (
            self.model_runner.flashinfer_workspace_buffers[0]
        )
        self.flashinfer_kv_indptr = torch.zeros(
            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.flashinfer_kv_indices = torch.zeros(
            (self.max_bs * model_runner.model_config.context_len,),
            dtype=torch.int32,
            device="cuda",
        )
        self.flashinfer_kv_last_page_len = torch.ones(
            (self.max_bs,), dtype=torch.int32, device="cuda"
        )

    def can_run(self, batch_size):
        return batch_size < self.max_bs

    def capture(self, batch_size_list):
        self.batch_size_list = batch_size_list
        with graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
            for bs in batch_size_list:
                (
                    graph,
                    input_buffers,
                    output_buffers,
                    flashinfer_handler,
                ) = self.capture_one_batch_size(bs)
                self.graphs[bs] = graph
                self.input_buffers[bs] = input_buffers
                self.output_buffers[bs] = output_buffers
                self.flashinfer_handlers[bs] = flashinfer_handler

    def capture_one_batch_size(self, bs):
        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

        graph = torch.cuda.CUDAGraph()
        stream = self.stream

        # Common inputs
        input_ids = self.input_ids[:bs]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        position_ids_offsets = self.position_ids_offsets[:bs]
        out_cache_loc = self.out_cache_loc[:bs]

        # FlashInfer inputs
        if not _grouped_size_compiled_for_decode_kernels(
            self.model_runner.model_config.num_attention_heads
            // self.model_runner.tp_size,
            self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
        ):
            use_tensor_cores = True
        else:
            use_tensor_cores = False
        flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
            self.flashinfer_workspace_buffer,
            "NHD",
            use_cuda_graph=True,
            use_tensor_cores=use_tensor_cores,
            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
            paged_kv_indices_buffer=self.flashinfer_kv_indices,
            paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
        )
        init_flashinfer_args(
            ForwardMode.DECODE,
            self.model_runner,
            req_pool_indices,
            seq_lens,
            None,
            flashinfer_decode_wrapper,
        )

        # Run and capture
        def run_once():
            input_metadata = InputMetadata.create(
                self.model_runner,
                forward_mode=ForwardMode.DECODE,
                req_pool_indices=req_pool_indices,
                seq_lens=seq_lens,
                prefix_lens=None,
                position_ids_offsets=position_ids_offsets,
                out_cache_loc=out_cache_loc,
                return_logprob=False,
                top_logprobs_nums=0,
                skip_flashinfer_init=True,
            )
            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
            return self.model_runner.model.forward(
                input_ids, input_metadata.positions, input_metadata
            )

        for _ in range(2):
            run_once()

        torch.cuda.synchronize()
        with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
            out = run_once()
        torch.cuda.synchronize()
        self.graph_memory_pool = graph.pool()
        return graph, None, out, flashinfer_decode_wrapper

    def replay(self, batch: Batch):
        assert batch.out_cache_loc is not None
        assert not batch.return_logprob
        raw_bs = len(batch.reqs)

        # Pad
        index = bisect.bisect_left(self.batch_size_list, raw_bs)
        bs = self.batch_size_list[index]
        if bs != raw_bs:
            self.seq_lens.zero_()
            self.position_ids_offsets.fill_(1)
            self.out_cache_loc.zero_()

        # Common inputs
        self.input_ids[:raw_bs] = batch.input_ids
        self.req_pool_indices[:raw_bs] = batch.req_pool_indices
        self.seq_lens[:raw_bs] = batch.seq_lens
        self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets
        self.out_cache_loc[:raw_bs] = batch.out_cache_loc

        # FlashInfer inputs
        init_flashinfer_args(
            ForwardMode.DECODE,
            self.model_runner,
            self.req_pool_indices[:bs],
            self.seq_lens[:bs],
            None,
            self.flashinfer_handlers[bs],
        )

        # Replay
        self.graphs[bs].replay()
        output = self.output_buffers[bs]

        # Unpad
        if bs == raw_bs:
            return output
        else:
            output = LogitProcessorOutput(
                next_token_logits=output.next_token_logits[:raw_bs],
                next_token_logprobs=output.next_token_logprobs[:raw_bs]
                if output.next_token_logprobs is not None
                else None,
                normalized_prompt_logprobs=None,
                prefill_token_logprobs=None,
                prefill_top_logprobs=None,
                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
                if output.decode_top_logprobs is not None
                else None,
            )
            return output
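The core trick behind `capture_one_batch_size` and `replay` above is the standard CUDA-graph pattern: allocate fixed-size buffers once, warm the kernels up, record one decode step into a `torch.cuda.CUDAGraph`, and at serve time copy fresh data into the same buffers and replay the recorded kernels. The following stand-alone sketch demonstrates that pattern with plain PyTorch; the tensors and `forward` function here are invented for illustration (not part of sglang), and a CUDA-capable GPU is assumed.

```python
import torch

# Static buffers: the captured graph always reads/writes these exact tensors.
static_x = torch.zeros((8, 16), device="cuda")
weight = torch.randn((16, 16), device="cuda")

def forward():
    return static_x @ weight

# Warm up on a side stream so one-time lazy initialization is not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        forward()
torch.cuda.current_stream().wait_stream(s)
torch.cuda.synchronize()

# Record one forward pass into the graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = forward()

# "Replay": copy new inputs into the static buffer and rerun the recorded kernels.
static_x.copy_(torch.randn((8, 16), device="cuda"))
graph.replay()
print(static_out)  # now holds the result for the new inputs
```

Because replay only works at the exact shapes used during capture, `replay` above pads an incoming batch up to the nearest captured batch size (via `bisect`) and slices the padded rows off the logits before returning.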
sglang/srt/managers/controller/dp_worker.py (new file)
@@ -0,0 +1,113 @@
"""A data parallel worker thread."""

import asyncio
import logging
import queue
import threading
from typing import Callable, List

import uvloop
import zmq

from sglang.global_config import global_config
from sglang.srt.managers.controller.tp_worker import ModelTpClient
from sglang.srt.managers.io_struct import BatchTokenIDOut
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
from sglang.utils import get_exception_traceback

logger = logging.getLogger("srt.controller")
CHECKING_INTERVAL = 5

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


class DataParallelWorkerThread(threading.Thread):
    def __init__(
        self,
        worker_id: int,
        request_queue: queue.Queue,
        detokenizer_port: int,
        step_func: Callable,
    ):
        super(DataParallelWorkerThread, self).__init__()
        self.worker_id = worker_id
        self.request_queue = request_queue
        self.liveness = True
        self.request_dependency_delay = global_config.request_dependency_delay

        context = zmq.asyncio.Context()
        self.send_to_detokenizer = context.socket(zmq.PUSH)
        self.send_to_detokenizer.connect(f"tcp://127.0.0.1:{detokenizer_port}")

        self.step = step_func

    async def loop_for_forward(self):
        while self.liveness:
            requests = []
            while not self.request_queue.empty():
                requests.append(self.request_queue.get())

            out_pyobjs: List[BatchTokenIDOut] = []
            try:
                out_pyobjs = await self.step(requests)
            except Exception:
                for r in requests:
                    self.request_queue.put(r)
                logger.error(
                    f"Worker thread {self.worker_id}: "
                    f"failed to get back from Model Server\n"
                    f"{get_exception_traceback()}"
                )
                self.liveness = False
                # Crash the whole server when there are any errors.
                # TODO(lianmin): make this an option.
                kill_parent_process()
                return

            for obj in out_pyobjs:
                self.send_to_detokenizer.send_pyobj(obj)

            # async sleep for receiving the subsequent request and avoiding cache miss
            if len(out_pyobjs) != 0:
                has_finished = any(
                    [obj.finished_reason is not None for obj in out_pyobjs]
                )
                if has_finished:
                    await asyncio.sleep(self.request_dependency_delay)
            await asyncio.sleep(global_config.wait_for_new_request_delay)

    async def monitoring(self):
        while True:
            await asyncio.sleep(CHECKING_INTERVAL)
            # can plug in monitoring logic here

    def run(self):
        logger.info(f"DataParallelWorkerThread {self.worker_id} start")
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.create_task(self.monitoring())
        loop.run_until_complete(self.loop_for_forward())


def start_data_parallel_worker(
    server_args: ServerArgs,
    port_args: PortArgs,
    model_overide_args,
    gpu_ids: List[int],
    worker_id: int,
):
    model_tp_client = ModelTpClient(
        gpu_ids,
        server_args,
        port_args.model_port_args[worker_id],
        model_overide_args,
    )
    worker_thread = DataParallelWorkerThread(
        worker_id=worker_id,
        request_queue=queue.Queue(),
        detokenizer_port=port_args.detokenizer_port,
        step_func=model_tp_client.step,
    )
    worker_thread.start()
    return worker_thread
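The worker above pairs a thread-safe `queue.Queue` (filled by the multi-replica controller) with a per-thread asyncio event loop that drains whatever is currently queued, calls the model step, and pushes the outputs to the detokenizer over ZeroMQ. Below is a minimal, self-contained sketch of that thread-plus-event-loop pattern; the names `DemoWorker` and `echo_step` are invented for illustration and do not exist in sglang.

```python
import asyncio
import queue
import threading
import time

async def echo_step(requests):
    # Stand-in for ModelTpClient.step: pretend to run one forward pass.
    await asyncio.sleep(0.01)
    return [f"processed {r}" for r in requests]

class DemoWorker(threading.Thread):
    """Owns its own event loop and drains a queue, like DataParallelWorkerThread."""

    def __init__(self, request_queue: queue.Queue):
        super().__init__(daemon=True)
        self.request_queue = request_queue

    async def loop_for_forward(self):
        while True:
            requests = []
            while not self.request_queue.empty():
                requests.append(self.request_queue.get())
            if requests:
                for out in await echo_step(requests):
                    print(out)
            await asyncio.sleep(0.01)  # yield before polling the queue again

    def run(self):
        asyncio.run(self.loop_for_forward())

q = queue.Queue()
DemoWorker(q).start()
for i in range(3):
    q.put(f"request-{i}")
time.sleep(0.2)  # give the daemon thread time to drain the queue
```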