sglang 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +1 -0
- sglang/backend/vertexai.py +147 -0
- sglang/lang/interpreter.py +8 -9
- sglang/lang/ir.py +21 -0
- sglang/srt/layers/context_flashattention_nopad.py +7 -1
- sglang/srt/layers/extend_attention.py +46 -1
- sglang/srt/managers/router/manager.py +2 -2
- sglang/srt/managers/router/model_rpc.py +7 -3
- sglang/srt/managers/router/model_runner.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/server_args.py +22 -4
- sglang/srt/utils.py +1 -1
- sglang/test/test_programs.py +4 -1
- {sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/METADATA +44 -12
- {sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/RECORD +19 -20
- sglang/backend/huggingface.py +0 -349
- sglang/backend/tgi.py +0 -190
- {sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/LICENSE +0 -0
- {sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/WHEEL +0 -0
- {sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
sglang/api.py
CHANGED
@@ -6,6 +6,7 @@ from sglang.backend.anthropic import Anthropic
 from sglang.backend.base_backend import BaseBackend
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.backend.vertexai import VertexAI
 from sglang.global_config import global_config
 from sglang.lang.ir import (
     SglExpr,
sglang/backend/vertexai.py
ADDED
@@ -0,0 +1,147 @@
+import os
+import warnings
+from typing import List, Optional, Union
+
+import numpy as np
+from sglang.backend.base_backend import BaseBackend
+from sglang.lang.chat_template import get_chat_template
+from sglang.lang.interpreter import StreamExecutor
+from sglang.lang.ir import SglSamplingParams
+
+try:
+    import vertexai
+    from vertexai.preview.generative_models import (
+        GenerationConfig,
+        GenerativeModel,
+        Image,
+    )
+except ImportError as e:
+    GenerativeModel = e
+
+
+class VertexAI(BaseBackend):
+    def __init__(self, model_name):
+        super().__init__()
+
+        if isinstance(GenerativeModel, Exception):
+            raise GenerativeModel
+
+        project_id = os.environ["GCP_PROJECT_ID"]
+        location = os.environ.get("GCP_LOCATION")
+        vertexai.init(project=project_id, location=location)
+
+        self.model_name = model_name
+        self.chat_template = get_chat_template("default")
+
+    def get_chat_template(self):
+        return self.chat_template
+
+    def generate(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            prompt = self.messages_to_vertexai_input(s.messages_)
+        else:
+            # single-turn
+            prompt = (
+                self.text_to_vertexai_input(s.text_, s.cur_images)
+                if s.cur_images
+                else s.text_
+            )
+        ret = GenerativeModel(self.model_name).generate_content(
+            prompt,
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+        )
+
+        comp = ret.text
+
+        return comp, {}
+
+    def generate_stream(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            prompt = self.messages_to_vertexai_input(s.messages_)
+        else:
+            # single-turn
+            prompt = (
+                self.text_to_vertexai_input(s.text_, s.cur_images)
+                if s.cur_images
+                else s.text_
+            )
+        generator = GenerativeModel(self.model_name).generate_content(
+            prompt,
+            stream=True,
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+        )
+        for ret in generator:
+            yield ret.text, {}
+
+    def text_to_vertexai_input(self, text, images):
+        input = []
+        # split with image token
+        text_segs = text.split(self.chat_template.image_token)
+        for image_path, image_base64_data in images:
+            text_seg = text_segs.pop(0)
+            if text_seg != "":
+                input.append(text_seg)
+            input.append(Image.from_bytes(image_base64_data))
+        text_seg = text_segs.pop(0)
+        if text_seg != "":
+            input.append(text_seg)
+        return input
+
+    def messages_to_vertexai_input(self, messages):
+        vertexai_message = []
+        # from openai message format to vertexai message format
+        for msg in messages:
+            if isinstance(msg["content"], str):
+                text = msg["content"]
+            else:
+                text = msg["content"][0]["text"]
+
+            if msg["role"] == "system":
+                warnings.warn("Warning: system prompt is not supported in VertexAI.")
+                vertexai_message.append(
+                    {
+                        "role": "user",
+                        "parts": [{"text": "System prompt: " + text}],
+                    }
+                )
+                vertexai_message.append(
+                    {
+                        "role": "model",
+                        "parts": [{"text": "Understood."}],
+                    }
+                )
+                continue
+            if msg["role"] == "user":
+                vertexai_msg = {
+                    "role": "user",
+                    "parts": [{"text": text}],
+                }
+            elif msg["role"] == "assistant":
+                vertexai_msg = {
+                    "role": "model",
+                    "parts": [{"text": text}],
+                }
+
+            # images
+            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
+                for image in msg["content"][1:]:
+                    assert image["type"] == "image_url"
+                    vertexai_msg["parts"].append(
+                        {
+                            "inline_data": {
+                                "data": image["image_url"]["url"].split(",")[1],
+                                "mime_type": "image/jpeg",
+                            }
+                        }
+                    )
+
+            vertexai_message.append(vertexai_msg)
+        return vertexai_message
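A minimal usage sketch of the new backend (not part of the package): it assumes `GCP_PROJECT_ID` (and optionally `GCP_LOCATION`) are exported, the Vertex AI SDK is installed, and uses an illustrative model name and prompt.

```python
# Minimal sketch, not part of the package: exercises the new VertexAI backend.
# Assumes GCP_PROJECT_ID (and optionally GCP_LOCATION) are set in the environment
# and the Vertex AI SDK is installed; "gemini-pro" is an illustrative model name.
import sglang as sgl
from sglang.backend.vertexai import VertexAI

sgl.set_default_backend(VertexAI("gemini-pro"))

@sgl.function
def multi_turn_qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

state = multi_turn_qa.run(question="What is the capital of France?")
print(state["answer"])
```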
sglang/lang/interpreter.py
CHANGED
@@ -365,11 +365,10 @@ class StreamExecutor:
         for comp, meta_info in generator:
             self.text_ += comp
             self.variables[name] += comp
+            self.meta_info[name] = meta_info
             self.stream_var_event[name].set()
             self.stream_text_event.set()

-        self.meta_info[name] = meta_info
-
         self.variable_event[name].set()
         self.stream_var_event[name].set()

@@ -428,6 +427,7 @@ class StreamExecutor:
             self.messages_.append(last_msg)
             self.cur_images = []
         else:
+            # OpenAI chat API format
             self.messages_.append({"role": expr.role, "content": new_text})

         self.cur_role = None
@@ -582,7 +582,7 @@ class ProgramState:
         else:
             yield self.get_var(name)

-    async def text_async_iter(self, var_name=None):
+    async def text_async_iter(self, var_name=None, return_meta_data=False):
        loop = asyncio.get_running_loop()

        if self.stream_executor.stream:
@@ -606,7 +606,10 @@ class ProgramState:
                    out = str(self.stream_executor.variables[var_name][prev:])
                    prev += len(out)
                    if out:
-                        yield out
+                        if return_meta_data:
+                            yield out, self.stream_executor.meta_info[var_name]
+                        else:
+                            yield out
                    if self.stream_executor.variable_event[var_name].is_set():
                        break
        else:
@@ -632,11 +635,7 @@ class ProgramState:
         self.stream_executor.end()

     def __repr__(self) -> str:
-        msgs = self.messages()
-        ret = ""
-        for msg in msgs:
-            ret += msg["role"] + ":\n" + msg["content"] + "\n"
-        return ret
+        return f"ProgramState({self.text()})"


 class ProgramStateGroup:
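The `return_meta_data` flag added to `text_async_iter` lets an async consumer receive the per-variable meta info alongside the streamed text. A rough sketch of how it could be consumed, assuming a streamed program with an `"answer"` variable (the program itself is illustrative):

```python
# Rough sketch, not from the package: consumes a streamed variable together
# with its meta info via the new return_meta_data flag. Only
# text_async_iter(var_name, return_meta_data=True) comes from this diff;
# the program and variable name are illustrative.
import asyncio
import sglang as sgl

@sgl.function
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

async def consume():
    state = qa.run(question="What is the capital of France?", stream=True)
    meta = None
    async for chunk, meta in state.text_async_iter("answer", return_meta_data=True):
        print(chunk, end="", flush=True)
    print("\nmeta_info:", meta)

asyncio.run(consume())
```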
sglang/lang/ir.py
CHANGED
@@ -2,6 +2,7 @@

 import dataclasses
 import inspect
+import warnings
 from typing import List, Optional, Union

 from sglang.global_config import global_config
@@ -40,6 +41,8 @@ class SglSamplingParams:

     def to_openai_kwargs(self):
         # OpenAI does not support top_k, so we drop it here
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the OpenAI backend.")
         return {
             "max_tokens": self.max_new_tokens,
             "stop": self.stop or None,
@@ -49,8 +52,26 @@ class SglSamplingParams:
             "presence_penalty": self.presence_penalty,
         }

+    def to_vertexai_kwargs(self):
+        if self.regex is not None:
+            warnings.warn(
+                "Regular expression is not supported in the VertexAI backend."
+            )
+        return {
+            "candidate_count": 1,
+            "max_output_tokens": self.max_new_tokens,
+            "stop_sequences": self.stop,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k if self.top_k > 0 else None,
+        }
+
     def to_anthropic_kwargs(self):
         # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
+        if self.regex is not None:
+            warnings.warn(
+                "Regular expression is not supported in the Anthropic backend."
+            )
         return {
             "max_tokens_to_sample": self.max_new_tokens,
             "stop_sequences": self.stop,
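For reference, a small sketch of the renaming that `to_vertexai_kwargs` performs relative to the OpenAI-style fields (`max_new_tokens` becomes `max_output_tokens`, `stop` becomes `stop_sequences`, and a non-positive `top_k` is passed as `None`); the values below are illustrative and assume the dataclass defaults for the remaining fields:

```python
# Illustrative only: shows the field renaming handled by to_vertexai_kwargs.
from sglang.lang.ir import SglSamplingParams

params = SglSamplingParams(max_new_tokens=64, temperature=0.7, top_p=0.9, top_k=-1)
print(params.to_vertexai_kwargs())
# Expected shape: candidate_count=1, max_output_tokens=64, stop_sequences=<stop>,
# temperature=0.7, top_p=0.9, top_k=None (because top_k <= 0)
```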
sglang/srt/layers/context_flashattention_nopad.py
CHANGED
@@ -5,6 +5,8 @@ import triton
 import triton.language as tl
 from sglang.srt.utils import wrap_kernel_launcher

+CUDA_CAPABILITY = torch.cuda.get_device_capability()
+

 @triton.jit
 def _fwd_kernel(
@@ -120,7 +122,11 @@ cached_kernel = None


 def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-    BLOCK = 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK = 128
+    else:
+        BLOCK = 64
+
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk and Lk == Lv
     assert Lk in {16, 32, 64, 128}
sglang/srt/layers/extend_attention.py
CHANGED
@@ -2,6 +2,9 @@ import torch
 import triton
 import triton.language as tl
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+from sglang.srt.utils import wrap_kernel_launcher
+
+CUDA_CAPABILITY = torch.cuda.get_device_capability()


 @triton.jit
@@ -153,6 +156,9 @@ def _fwd_kernel(
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])


+cached_kernel = None
+
+
 def extend_attention_fwd(
     q_extend,
     k_extend,
@@ -175,7 +181,11 @@ def extend_attention_fwd(

     k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
     """
-    BLOCK_M, BLOCK_N = 128, 128
+    if CUDA_CAPABILITY[0] >= 8:
+        BLOCK_M, BLOCK_N = 128, 128
+    else:
+        BLOCK_M, BLOCK_N = 64, 64
+
     Lq, Lk, Lv, Lo = (
         q_extend.shape[-1],
         k_extend.shape[-1],
@@ -193,6 +203,40 @@ def extend_attention_fwd(
     num_warps = 4 if Lk <= 64 else 8
     num_stages = 1

+    global cached_kernel
+    if cached_kernel:
+        cached_kernel(
+            grid,
+            num_warps,
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend,
+            k_buffer,
+            v_buffer,
+            req_to_tokens,
+            b_req_idx,
+            b_seq_len,
+            b_start_loc_extend,
+            b_seq_len_extend,
+            sm_scale,
+            kv_group_num,
+            q_extend.stride(0),
+            q_extend.stride(1),
+            k_extend.stride(0),
+            k_extend.stride(1),
+            v_extend.stride(0),
+            v_extend.stride(1),
+            o_extend.stride(0),
+            o_extend.stride(1),
+            k_buffer.stride(0),
+            k_buffer.stride(1),
+            v_buffer.stride(0),
+            v_buffer.stride(1),
+            req_to_tokens.stride(0),
+        )
+        return
+
     _fwd_kernel[grid](
         q_extend,
         k_extend,
@@ -226,6 +270,7 @@ def extend_attention_fwd(
         num_warps=num_warps,
         num_stages=num_stages,
     )
+    cached_kernel = wrap_kernel_launcher(_fwd_kernel)


 def redundant_attention(
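Both Triton kernels now pick their block sizes from the device's compute capability (128 on capability 8.0 and newer, 64 on older GPUs such as T4/V100, which have tighter shared-memory budgets), and `wrap_kernel_launcher` caches the compiled launcher so later calls skip Triton's Python-side dispatch. A standalone sketch of the capability check only (not the sglang kernel itself):

```python
# Standalone sketch of the capability-based tile selection used above;
# the block sizes mirror the diff, the print is just for illustration.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    BLOCK_M, BLOCK_N = (128, 128) if major >= 8 else (64, 64)
    print(f"sm_{major}{minor}: BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}")
```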
sglang/srt/managers/router/manager.py
CHANGED
@@ -28,7 +28,7 @@ class RouterManager:
         self.model_client = model_client
         self.recv_reqs = []

-        # Init
+        # Init some configs
         self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time

     async def loop_for_forward(self):
@@ -46,7 +46,7 @@ class RouterManager:
             if has_finished:
                 await asyncio.sleep(self.extend_dependency_time)

-            await asyncio.sleep(0.
+            await asyncio.sleep(0.0006)

     async def loop_for_recv_requests(self):
         while True:
sglang/srt/managers/router/model_rpc.py
CHANGED
@@ -2,6 +2,7 @@ import asyncio
 import logging
 import multiprocessing
 import time
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum, auto
 from typing import Dict, List, Optional, Tuple, Union
@@ -44,6 +45,7 @@ class ModelRpcServer(rpyc.Service):
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
+        self.schedule_conservativeness = server_args.schedule_conservativeness

         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -107,7 +109,7 @@ class ModelRpcServer(rpyc.Service):
         self.running_batch: Batch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
-        self.stream_interval =
+        self.stream_interval = server_args.stream_interval

         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(self.tokenizer)
@@ -164,7 +166,7 @@ class ModelRpcServer(rpyc.Service):
             + self.tree_cache.evictable_size()
         )
         if available_size != self.max_total_num_token:
-            print(
+            warnings.warn(
                 "Warning: "
                 f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                 "KV cache pool leak detected!"
@@ -247,7 +249,9 @@ class ModelRpcServer(rpyc.Service):
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-        new_ratio =
+        new_ratio = (
+            self.scheduler.new_token_estimation_ratio() * self.schedule_conservativeness
+        )
         if self.running_batch:
             available_size -= sum(
                 [
sglang/srt/models/mixtral.py
CHANGED
@@ -355,7 +355,7 @@ class MixtralForCausalLM(nn.Module):
         ):
             if "rotary_emb.inv_freq" in name:
                 continue
-            for
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
sglang/srt/server_args.py
CHANGED
@@ -16,7 +16,9 @@ class ServerArgs:
     tp_size: int = 1
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
+    schedule_conservativeness: float = 1.0
     random_seed: int = 42
+    stream_interval: int = 2
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
@@ -25,10 +27,14 @@ class ServerArgs:
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
-            if self.tp_size
-                self.mem_fraction_static = 0.
+            if self.tp_size >= 8:
+                self.mem_fraction_static = 0.80
+            elif self.tp_size >= 4:
+                self.mem_fraction_static = 0.82
+            elif self.tp_size >= 2:
+                self.mem_fraction_static = 0.85
             else:
-                self.mem_fraction_static = 0.
+                self.mem_fraction_static = 0.90

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -80,7 +86,7 @@ class ServerArgs:
             "--mem-fraction-static",
             type=float,
             default=ServerArgs.mem_fraction_static,
-            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool)",
+            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
         parser.add_argument(
             "--tp-size",
@@ -102,12 +108,24 @@ class ServerArgs:
             default=ServerArgs.schedule_heuristic,
             help="Schudule mode: [lpm, weight, random, fcfs]",
         )
+        parser.add_argument(
+            "--schedule-conservativeness",
+            type=float,
+            default=ServerArgs.schedule_conservativeness,
+            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see out-of-memory errors.",
+        )
         parser.add_argument(
             "--random-seed",
             type=int,
             default=ServerArgs.random_seed,
             help="Random seed.",
         )
+        parser.add_argument(
+            "--stream-interval",
+            type=int,
+            default=ServerArgs.stream_interval,
+            help="The interval in terms of token length for streaming",
+        )
         parser.add_argument(
             "--log-level",
             type=str,
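The two new options compose with the existing launch flags; for example (model path and values are illustrative, following the launch command shown in the README section of this diff):

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 \
  --stream-interval 4 --schedule-conservativeness 1.3 --mem-fraction-static 0.85
```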
sglang/srt/utils.py
CHANGED
@@ -209,7 +209,7 @@ def load_image(image_file):
     elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
-        image_file =
+        image_file = image_file.split(",")[1]
         image = Image.open(BytesIO(base64.b64decode(image_file)))
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))
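The fix keeps only the part after the first comma of a data URL, which is the base64 payload that `b64decode` expects. A tiny illustration (the payload here is a placeholder, not a real image):

```python
# Tiny illustration of the data-URL handling fixed above; the payload is a
# placeholder string, not a decodable image.
data_url = "data:image/jpeg;base64,AAAA"
payload = data_url.split(",")[1]
print(payload)  # -> AAAA
```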
sglang/test/test_programs.py
CHANGED
@@ -304,7 +304,10 @@ def test_image_qa():
         temperature=0,
         max_new_tokens=64,
     )
-    assert
+    assert (
+        "taxi" in state.messages()[-1]["content"]
+        or "car" in state.messages()[-1]["content"]
+    )


 def test_stream():
{sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.1.3
+Version: 0.1.5
 Summary: A structured generation language for LLMs.
 License: Apache License
                                Version 2.0, January 2004
@@ -234,6 +234,7 @@ Requires-Dist: lark ; extra == 'srt'
 Requires-Dist: numba ; extra == 'srt'

 # SGLang
+| [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |

 SGLang is a structured generation language designed for large language models (LLMs).
 It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.
@@ -267,10 +268,20 @@ pip install --upgrade pip
 pip install -e "python[all]"
 ```

+### Notes
+- If you are using older GPUs (NVIDIA T4, V100), please use `pip install "triton>=2.2.0"` to avoid some bugs in the triton compiler.
+- If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install sglang[openai]`.
+
 ## Quick Start
 The example below shows how to use sglang to answer a multi-turn question.

 ### Using OpenAI Models
+Set the OpenAI API Key
+```
+export OPENAI_API_KEY=sk-******
+```
+
+Then, answer a multi-turn question.
 ```python
 from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI

@@ -325,6 +336,7 @@ for m in state.messages():

 ### More Examples

+Anthropic and VertexAI (Gemini) models are also supported.
 You can find more examples at [examples/quick_start](examples/quick_start).

 ## Frontend: Structured Generation Language (SGLang)
@@ -334,19 +346,20 @@ To begin with, import sglang.
 import sglang as sgl
 ```

-`sglang` provides some simple primitives such as `gen`, `select`, `fork`.
+`sglang` provides some simple primitives such as `gen`, `select`, `fork`, `image`.
 You can implement your prompt flow in a function decorated by `sgl.function`.
 You can then invoke the function with `run` or `run_batch`.
 The system will manage the state, chat template, and parallelism for you.

 ### Control Flow
+You can use any Python code within the function body, including control flow, nested function calls, and external libraries.
+
 ```python
 @sgl.function
 def control_flow(s, question):
     s += "To answer this question: " + question + ", "
     s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "

-    # You can use if or nested function calls
     if s["tool"] == "calculator":
         s += "The math expression is" + sgl.gen("expression")
     elif s["tool"] == "web browser":
@@ -354,6 +367,9 @@ def control_flow(s, question):
 ```

 ### Parallelism
+Use `fork` to launch parallel prompts.
+Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel.
+
 ```python
 @sgl.function
 def tip_suggestion(s):
@@ -362,7 +378,7 @@ def tip_suggestion(s):
         "1. Balanced Diet. 2. Regular Exercise.\n\n"
     )

-    forks = s.fork(2)
+    forks = s.fork(2)
     for i, f in enumerate(forks):
         f += f"Now, expand tip {i+1} into a paragraph:\n"
         f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n")
@@ -373,6 +389,8 @@ def tip_suggestion(s):
 ```

 ### Multi Modality
+Use `sgl.image` to pass an image as input.
+
 ```python
 @sgl.function
 def image_qa(s, image_file, question):
@@ -381,11 +399,13 @@ def image_qa(s, image_file, question):
 ```

 ### Constrained Decoding
+Use `regex=` to specify a regular expression as a decoding constraint.
+
 ```python
-@function
+@sgl.function
 def regular_expression_gen(s):
     s += "Q: What is the IP address of the Google DNS servers?\n"
-    s += "A: " + gen(
+    s += "A: " + sgl.gen(
         "answer",
         temperature=0,
         regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
@@ -393,6 +413,8 @@ def regular_expression_gen(s):
 ```

 ### Batching
+Use `run_batch` to run a batch of requests with continuous batching.
+
 ```python
 @sgl.function
 def text_qa(s, question):
@@ -405,10 +427,13 @@ states = text_qa.run_batch(
         {"question": "What is the capital of France?"},
         {"question": "What is the capital of Japan?"},
     ],
+    progress_bar=True
 )
 ```

 ### Streaming
+Add `stream=True` to enable streaming.
+
 ```python
 @sgl.function
 def text_qa(s, question):
@@ -417,7 +442,9 @@ def text_qa(s, question):

 states = text_qa.run(
     question="What is the capital of France?",
-    temperature=0.1
+    temperature=0.1,
+    stream=True
+)

 for out in state.text_iter():
     print(out, end="", flush=True)
@@ -426,7 +453,7 @@ for out in state.text_iter():
 ## Backend: SGLang Runtime (SRT)
 The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
 However, it can also be used as a standalone API server.
-In this case, the RadixAttention can still greatly accelerate many use cases.
+In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases with automatic KV cache reuse.

 ### Usage
 Launch a server
@@ -450,6 +477,10 @@ curl http://localhost:30000/v1/completions \
 ```
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
 ```
+- If you see out-of-memory errors during serving, please try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
+```

 ### Supported Models
 - Llama
@@ -457,6 +488,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 - Mixtral
 - LLaVA
   - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
+- AWQ quantization

 ## Benchmark And Performance

@@ -466,13 +498,13 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 - Mixtral-8x7B on NVIDIA A10G, FP16, Tensor Parallelism=8
 ![mixtral_8x7b](assets/mixtral_8x7b.jpg)

-Learn more [here]().
+Learn more [here](docs/benchmark_results.md).

 ## Roadmap
-- [ ] Function call
-- [ ] Quantization
+- [ ] Function call APIs
 - [ ] S-LoRA
-- [ ]
+- [ ] Support more models
+- [ ] Support more hardware backends

 ## Citation And Acknowledgment
 ```
{sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
-sglang/__init__.py,sha256=
-sglang/api.py,sha256=
+sglang/__init__.py,sha256=G73L_PWJ_6mF3NIE4ZAOWcb1CUbETSeRWr3wDTePrZ4,95
+sglang/api.py,sha256=SxmPP_PMYi4DfUcwz_V9UvYOwGmQdHPgpMV6jDDJq68,3928
 sglang/flush_cache.py,sha256=cCD_MTlQ5qEv__w0nOthDnVitdAfyscYjksBljwC5Mw,1835
 sglang/global_config.py,sha256=PAX7TWeFcq0HBzNUWyCONAOjqIokWqw8vT7I6sBSKTc,797
 sglang/launch_server.py,sha256=jKPZRDN5bUe8Wgz5eoDkqeePhmKa8DLD4DpXQLT5auo,294
@@ -7,15 +7,14 @@ sglang/utils.py,sha256=tvJs95QGZ_PcnTjvm-CDGQ8dJe84qUUOfG7BeF79nsA,5670
 sglang/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sglang/backend/anthropic.py,sha256=y5TN9EDrJtOH4JEUxpXu-endloeYBy7xMUr3r7Ah3MA,1462
 sglang/backend/base_backend.py,sha256=pPalZfoezxnUBs752j7lm0uMwa8tZuCWd-ijSdStMO8,1745
-sglang/backend/huggingface.py,sha256=roQlt8y41PQbmnAY47CXiR0KJaxhtljH6j8RhbsR4f0,10533
 sglang/backend/openai.py,sha256=umTWzC2p4PypDaXHe6Kc8By5IM_Doi0Ob97vK_fFWDc,7367
 sglang/backend/runtime_endpoint.py,sha256=rIhwtKJaLLCJAc6q6kqxEVC8xO_NNjmJs7BnxlOydLM,5860
-sglang/backend/
+sglang/backend/vertexai.py,sha256=BLfWf_tEgoHY9srCufJM5PLe3tql2j0G6ia7cPykxCM,4713
 sglang/lang/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sglang/lang/chat_template.py,sha256=1x4724K2oxu7VID40-5Megk7SbZb97PQCbRjLpoescU,5599
 sglang/lang/compiler.py,sha256=wNn_UqV6Sxl22mv-PpzFUtRgiFFV-Y4OYpO4LshEoRM,7527
-sglang/lang/interpreter.py,sha256=
-sglang/lang/ir.py,sha256=
+sglang/lang/interpreter.py,sha256=0WTJxCB57WDBr_E6kW39wByhcG2nRFjEMTzOjAaNhrY,22453
+sglang/lang/ir.py,sha256=uUnBRyaM-8suVOEb2qf4EAt_VN2pWbXV6V88jLk6wsI,13160
 sglang/lang/tracer.py,sha256=zH9DENdJBPEvWkThgwqvHOW7aC1EPC8xha_WpEj-SRs,8243
 sglang/srt/backend_config.py,sha256=7MdHjNsZeAKB9IWWxyrvyOjJJAdI5tl9hWl-MV7yHrI,226
 sglang/srt/hf_transformers_utils.py,sha256=soRyYLoCn7GxgxvonufGFkdFBA3eH5i3Izk_wi7p1l0,5285
@@ -23,14 +22,14 @@ sglang/srt/memory_pool.py,sha256=cN3Lrs9fn0DFmt67_IN4g06mPzKUxpbAJGUw4O33xbo,360
 sglang/srt/model_config.py,sha256=R7YaR8H8AmCJl_1XcSP0zII_5ebZNl0wMXNVANGWd2c,997
 sglang/srt/sampling_params.py,sha256=Sd9l_uIIuS_mhbzljKwTGDO9ESMviNOYGxOifc71RrY,2895
 sglang/srt/server.py,sha256=XxTS1K4N5y-ZknLBQefxk1UxC50l6DABVqJOrJ-NG74,6388
-sglang/srt/server_args.py,sha256=
-sglang/srt/utils.py,sha256
+sglang/srt/server_args.py,sha256=ojox8nu2tgPEy_JlKKEvRenby4HKkmWk-1MpHy3PmnI,5771
+sglang/srt/utils.py,sha256=-2F99bqYT99x1jScMjciJxgQec6CaH6PcCHSmrKHhhY,5692
 sglang/srt/constrained/fsm.py,sha256=H4kXSsV4IX2ow5TMmnmd-8ho4qqJ5mpVZ4MOH5FUtnY,12900
 sglang/srt/constrained/fsm_cache.py,sha256=KX4bFX5hj0W66SC9pSvst1ew7etaOMTtTC75z0enRME,1087
 sglang/srt/constrained/regex.py,sha256=CcV7KBOKS2ZxGoEr6BHG5okagNIGEXYvGvhKXu5gtDA,18689
 sglang/srt/constrained/tokenizer.py,sha256=rei9yKHFETcbDPOpI7bpIYdrBFgIBhGr_U-zb3r5Beo,7951
-sglang/srt/layers/context_flashattention_nopad.py,sha256=
-sglang/srt/layers/extend_attention.py,sha256=
+sglang/srt/layers/context_flashattention_nopad.py,sha256=GkjLiTkS4px_uLcW0aDocE3_OBXtujZ-SlsN2b2U7ng,5204
+sglang/srt/layers/extend_attention.py,sha256=pWVE6ySnPiVLFON__bie73eDhmXHk4tECMK8zTiJNbI,12558
 sglang/srt/layers/get_selected_logprob.py,sha256=CpMXM9WXMSB-AVaxBB_aVl1Qx_ZtAFFnjDTm4CgNDpU,2199
 sglang/srt/layers/logits_processor.py,sha256=rwcXwdZ7-dW9zvJX3MF_EHSxMLbU7TIQ9xUIYRu-WAs,3013
 sglang/srt/layers/radix_attention.py,sha256=hmPNFg2TkN4EAVUj376N_89RRtUYRwFgUpjj5SydnRk,6170
@@ -40,18 +39,18 @@ sglang/srt/managers/io_struct.py,sha256=5jMWj6_U8yTQd5V3tpDtThnoFyF0A3ln-4Z5bSL3
 sglang/srt/managers/openai_protocol.py,sha256=Eid_734Wup4jsL1ZS2Op0vwRuzvNbF4mV2UcwFxqEvI,327
 sglang/srt/managers/tokenizer_manager.py,sha256=jVwr0lM18RFJLhDb5TWlUpQ4Q8tALT4L6GY0jmaZkLw,7861
 sglang/srt/managers/router/infer_batch.py,sha256=UfS1uVhGnM-62Xv1cfu_IoTeIUxkjkKc4W3trtGbadc,11541
-sglang/srt/managers/router/manager.py,sha256=
-sglang/srt/managers/router/model_rpc.py,sha256=
-sglang/srt/managers/router/model_runner.py,sha256=
+sglang/srt/managers/router/manager.py,sha256=AVCdYKKYcIQsIwpudkfFY4jh6M--ubLjXeYGzfi2ebw,2528
+sglang/srt/managers/router/model_rpc.py,sha256=CR3qbHvShttlC19qAZ8B8nhT6UPobeu2Dy3Z0n6WdC8,19448
+sglang/srt/managers/router/model_runner.py,sha256=IhSdpBcd54HN01HDi_PAkJztFxEGDcnktdoPZDWEx3s,16487
 sglang/srt/managers/router/radix_cache.py,sha256=ZQPm9HhQ7vD3Gl5nhuvw3ZW4ZRARcplqWed1GYUvHCg,6441
 sglang/srt/managers/router/scheduler.py,sha256=ejuIRwqqMZVXFKUionRJxy5AtNvK25YoGRO9rFY-rc8,2926
 sglang/srt/models/llama2.py,sha256=D3j-NtyM8PA74UhXM7wSPogI2HKX-JcQAWcOusrZZo0,11320
 sglang/srt/models/llava.py,sha256=COS0IC6Yo-QiwKe5emgCbtEe9HgaSu5tt6CQA7UtV38,8533
-sglang/srt/models/mixtral.py,sha256=
-sglang/test/test_programs.py,sha256=
+sglang/srt/models/mixtral.py,sha256=frd2XsNZwP0XsQtRiYhgy4PErLNLgtIsLakmNrOKBAU,13712
+sglang/test/test_programs.py,sha256=EovA2xL7fODcTbFj2wAAmYKlg1mLZ1x1BRU6nrXFRdE,11416
 sglang/test/test_utils.py,sha256=Knxg3BTA6d_7XSlprbBCdvfDr2SN5x7LhkT-tZFk5EQ,4828
-sglang-0.1.
-sglang-0.1.
-sglang-0.1.
-sglang-0.1.
-sglang-0.1.
+sglang-0.1.5.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+sglang-0.1.5.dist-info/METADATA,sha256=aepmAL6VoXRcxZBIDKvxwikCYSbvWFm_JFGTxb3Mgfw,23345
+sglang-0.1.5.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+sglang-0.1.5.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
+sglang-0.1.5.dist-info/RECORD,,
sglang/backend/huggingface.py
DELETED
@@ -1,349 +0,0 @@
-import functools
-from enum import Enum, auto
-from typing import Callable, List, Optional, Union
-
-import numpy as np
-import torch
-import transformers
-from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
-from sglang.lang.interpreter import ProgramState
-from sglang.utils import get_available_gpu_memory
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-)
-from transformers.generation.logits_process import (
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
-)
-
-
-class StopReason(Enum):
-    EOS_TOKEN = auto()
-    STOP_STR = auto()
-    LENGTH = auto()
-
-
-def load_model(
-    model_name: str,
-    device,
-    num_gpus,
-    max_gpu_memory,
-    model_kwargs=None,
-    tokenizer_kwargs=None,
-):
-    model_kwargs = model_kwargs or {}
-    tokenizer_kwargs = tokenizer_kwargs or {}
-
-    if device == "cuda":
-        model_kwargs["torch_dtype"] = torch.float16
-        if num_gpus != 1:
-            model_kwargs["device_map"] = "auto"
-            if max_gpu_memory is None:
-                model_kwargs[
-                    "device_map"
-                ] = "sequential"  # This is important for not the same VRAM sizes
-                available_gpu_memory = [
-                    get_available_gpu_memory(i, False) for i in range(num_gpus)
-                ]
-                model_kwargs["max_memory"] = {
-                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
-                    for i in range(num_gpus)
-                }
-            else:
-                model_kwargs["max_memory"] = {
-                    i: max_gpu_memory for i in range(num_gpus)
-                }
-    elif device == "cpu":
-        model_kwargs["torch_dtype"] = torch.float32
-    else:
-        raise ValueError(f"Invalid device: {device}")
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name, low_cpu_mem_usage=True, **model_kwargs
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
-
-    if num_gpus == 1:
-        model.to(device).eval()
-
-    return model, tokenizer
-
-
-def prepare_logits_processor(
-    temperature: float, repetition_penalty: float, top_p: float, top_k: int
-) -> LogitsProcessorList:
-    processor_list = LogitsProcessorList()
-    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
-    if temperature >= 1e-5 and temperature != 1.0:
-        processor_list.append(TemperatureLogitsWarper(temperature))
-    if repetition_penalty > 1.0:
-        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
-    if 1e-8 <= top_p < 1.0:
-        processor_list.append(TopPLogitsWarper(top_p))
-    if top_k > 0:
-        processor_list.append(TopKLogitsWarper(top_k))
-    return processor_list
-
-
-@functools.lru_cache
-def get_token_healing_mask(tokenizer, prompt_last_token):
-    last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
-    disallowed = torch.zeros(len(tokenizer), dtype=bool)
-    for s, t_id in tokenizer.get_vocab().items():
-        if not s.startswith(last_str):
-            disallowed[t_id] = 1
-    return disallowed
-
-
-@functools.lru_cache
-def get_int_token_mask(tokenizer):
-    disallowed = torch.zeros(len(tokenizer), dtype=bool)
-    for s, t_id in tokenizer.get_vocab().items():
-        s = s.replace("▁", "").strip()
-        if not (s.isdigit() or len(s) == 0 or s == ","):
-            disallowed[t_id] = 1
-    disallowed[tokenizer.eos_token_id] = 0
-    return disallowed
-
-
-@torch.inference_mode()
-def generate_stream(
-    model,
-    tokenizer,
-    prompt,
-    max_new_tokens,
-    stop: List[str],
-    temperature,
-    top_p,
-    token_healing,
-    logit_mask=None,
-):
-    logits_processor = prepare_logits_processor(
-        temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
-    )
-    device = model.device
-    input_ids = tokenizer.encode(prompt)
-    output_ids = list(input_ids)
-    prompt_len = len(prompt)
-
-    # Resolve stop
-    stop_token_ids = [tokenizer.eos_token_id]
-
-    # Token healing
-    token_healing = token_healing and len(input_ids) > 0
-    if token_healing:
-        token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
-        del output_ids[-1]
-
-    # Generate
-    past_key_values = None
-    stop_reason = None
-    for i in range(max_new_tokens):
-        # Forward
-        if i == 0:  # prefill
-            out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
-        else:  # decoding
-            out = model(
-                input_ids=torch.as_tensor([[token]], device=device),
-                use_cache=True,
-                past_key_values=past_key_values,
-            )
-        logits = out.logits
-        past_key_values = out.past_key_values
-
-        # Logit mask
-        if token_healing and i == 0:
-            logits[0, -1, token_healing_mask] = -1e4
-        if logit_mask is not None:
-            logits[0, -1, logit_mask] = -1e4
-
-        # Sample next token
-        last_token_logits = logits_processor(None, logits[:, -1, :])[0]
-        if temperature < 1e-5 or top_p < 1e-8:  # greedy
-            token = int(torch.argmax(last_token_logits))
-        else:
-            probs = torch.softmax(last_token_logits, dim=-1)
-            token = int(torch.multinomial(probs, num_samples=1))
-        output_ids.append(token)
-
-        # Stop condition
-        if token in stop_token_ids:
-            stop_reason = StopReason.EOS_TOKEN
-            break
-
-        output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
-        for stop_str in stop:
-            pos = output_str[prompt_len:].find(stop_str)
-            if pos != -1:
-                stop_reason = StopReason.STOP_STR
-                output_str = output_str[: prompt_len + pos]
-                break
-
-        if stop_reason:
-            break
-
-    return output_str[prompt_len:]
-
-
-class HuggingFaceTransformers(BaseBackend):
-    def __init__(
-        self,
-        model_name,
-        device="cuda",
-        num_gpus=1,
-        max_gpu_memory=None,
-        model_kwargs=None,
-        tokenizer_kwargs=None,
-    ):
-        self.model_name = model_name
-        self.device = device
-
-        self.model, self.tokenizer = load_model(
-            model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
-        )
-
-        self.chat_template = get_chat_template_by_model_path(model_name)
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    def cache_prefix(self, prefix_str: str):
-        pass
-
-    def uncache_prefix(self, rid: str):
-        pass
-
-    def end_request(self, rid: str):
-        pass
-
-    def begin_program(self, s: ProgramState):
-        pass
-
-    def end_program(self, s: ProgramState):
-        pass
-
-    def fill(self, s: ProgramState, text: str):
-        return False
-
-    def generate_internal(
-        self,
-        prompt: str,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        temperature: float,
-        top_p: float,
-        dtype: Optional[str] = None,
-    ):
-        if dtype is None:
-            comp = generate_stream(
-                self.model,
-                self.tokenizer,
-                prompt,
-                max_new_tokens=max_tokens,
-                stop=stop,
-                temperature=temperature,
-                top_p=top_p,
-                token_healing=True,
-            )
-        elif dtype in [str, "str", "string"]:
-            comp = generate_stream(
-                self.model,
-                self.tokenizer,
-                prompt + '"',
-                max_new_tokens=max_tokens,
-                stop=['"'],
-                temperature=temperature,
-                top_p=top_p,
-                token_healing=False,
-            )
-            comp = '"' + comp + '"'
-        elif dtype in [int, "int"]:
-            logit_mask = get_int_token_mask(self.tokenizer)
-            comp = generate_stream(
-                self.model,
-                self.tokenizer,
-                prompt,
-                max_new_tokens=max_tokens,
-                stop=stop + [" ", ","],
-                temperature=temperature,
-                top_p=top_p,
-                token_healing=False,
-                logit_mask=logit_mask,
-            )
-        return comp
-
-    def generate(
-        self,
-        s: ProgramState,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        temperature: float,
-        top_p: float,
-        dtype: Optional[str] = None,
-    ):
-        prompt = s.text
-        comp = self.generate_internal(
-            prompt, max_tokens, stop, temperature, top_p, dtype
-        )
-        return comp
-
-    def parallel_generate(
-        self,
-        s: ProgramState,
-        prefixes: List[str],
-        join_func: Callable,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        temperature: float,
-        top_p: float,
-        dtype: Optional[str] = None,
-    ):
-        prompt = s.text
-        parallel_prompts = [prompt + prefix for prefix in prefixes]
-
-        comps = []
-        for i in range(len(parallel_prompts)):
-            comps.append(
-                self.generate_internal(
-                    parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
-                )
-            )
-
-        joined = join_func([p + c for p, c in zip(prefixes, comps)])
-        return joined, comps
-
-    @torch.inference_mode()
-    def select(
-        self, s: ProgramState, choices: List[str], temperature: float, top_p: float
-    ):
-        loss_fct = torch.nn.CrossEntropyLoss()
-        prompt = s.text
-
-        prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
-        prompt_choices = [prompt + choice for choice in choices]
-
-        scores = []
-        for i in range(len(choices)):
-            choice_ids = self.tokenizer.encode(
-                prompt_choices[i], return_tensors="pt"
-            ).to(self.model.device)
-            logits = self.model(choice_ids).logits
-
-            # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()
-
-            logprobs = torch.log(torch.softmax(logits, dim=-1))
-            idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
-            idx2 = choice_ids[0, 1:]
-            selected_logprobs = logprobs[0, idx1, idx2]
-            score = selected_logprobs.mean().item()
-            scores.append(score)
-
-        decision = choices[np.argmax(scores)]
-        return decision, scores
sglang/backend/tgi.py
DELETED
@@ -1,190 +0,0 @@
-import re
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-from itertools import repeat
-from typing import List, Optional, Union
-
-from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
-from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglSamplingParams
-from sglang.utils import http_request
-
-
-class TGI(BaseBackend):
-    def __init__(self, base_url):
-        super().__init__()
-
-        self.base_url = base_url
-
-        res = http_request(self.base_url + "/info")
-        assert res.status_code == 200
-        self.model_info = res.json()
-        self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_id"]
-        )
-
-    def get_model_name(self):
-        return self.model_info["model_id"]
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    @staticmethod
-    def adapt_params(max_tokens, stop, sampling_params, **override_params):
-        temperature = sampling_params.temperature
-        do_sample = True
-        if temperature == 0:
-            do_sample = False
-            temperature = None
-
-        if stop is None:
-            stop = []
-        elif isinstance(stop, str):
-            stop = [stop]
-
-        top_p = sampling_params.top_p
-        if top_p == 0:
-            top_p = 0.001
-        if top_p == 1:
-            top_p = 0.999
-
-        top_k = sampling_params.top_k
-        if top_k == -1:
-            top_k = None
-
-        params = {
-            "decoder_input_details": False,
-            "details": False,
-            "do_sample": do_sample,
-            "max_new_tokens": max_tokens,
-            "stop": stop,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "return_full_text": False,
-        }
-        params.update(override_params)
-        return params
-
-    @staticmethod
-    def _extract_int(text):
-        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
-        for word in words:
-            try:
-                int(word)
-                return word
-            except ValueError:
-                continue
-        raise ValueError
-
-    @staticmethod
-    def _extract_choice(choices, text):
-        # FIXME: Current only support the case where the choices are single words.
-        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
-        for word in words:
-            if word in choices:
-                return word
-        raise ValueError
-
-    @staticmethod
-    def _truncate_to_stop(text, stop):
-        # The stop sequence may not be a single token. In this case TGI will generate
-        # too many tokens so we need to truncate the output.
-        if stop:
-            stop = [stop] if isinstance(stop, str) else stop
-            for stop_seq in stop:
-                pos = text.find(stop_seq)
-                if pos != -1:
-                    return text[:pos]
-        return text
-
-    def _make_request(self, params):
-        res = http_request(self.base_url + "/generate", json=params)
-        if res.status_code != 200:
-            raise ValueError(f"Error from TGI backend: {res.text}")
-        return res.json()
-
-    def retry_for_expected(self, prompt, params, extract_fn, retry=5):
-        # TGI does not support logis_bias (yet), so we have to use an inefficient hack.
-        failed = []
-        while retry > 0:
-            res_json = self._make_request(
-                {
-                    "inputs": prompt,
-                    "parameters": params,
-                }
-            )
-            text = res_json["generated_text"]
-            try:
-                return extract_fn(text)
-            except ValueError:
-                retry -= 1
-                failed.append(text)
-
-        msg = "=" * 20 + "\n"
-        msg += f"Prompt:\n{prompt}\n"
-        msg += "=" * 20 + "\n"
-        for i, text in enumerate(failed):
-            msg += f"====== Try {i+1}:\n{text}\n"
-
-        raise ValueError(
-            f"Model {self.model_info['model_id']} served by TGI backend does not generate"
-            "expected output. Please improve the prompt, increase the temperature, or "
-            f"use different models.\n{msg}"
-        )
-
-    def select(
-        self,
-        s: StreamExecutor,
-        choices: List[str],
-        sampling_params: SglSamplingParams,
-    ):
-        decision = self.retry_for_expected(
-            s.text_,
-            self.adapt_params(16, [], sampling_params),
-            partial(self._extract_choice, choices),
-        )
-        return decision, [1 if choice == decision else 0 for choice in choices]
-
-    def generate(
-        self,
-        s: StreamExecutor,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        sampling_params: SglSamplingParams,
-        dtype: Optional[str] = None,
-    ):
-        if dtype is None:
-            res_json = self._make_request(
-                {
-                    "inputs": s.text_,
-                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
-                }
-            )
-            return self._truncate_to_stop(res_json["generated_text"], stop), {}
-
-        if dtype in [str, "str", "string"]:
-            stop = ['"']
-            res_json = self._make_request(
-                {
-                    "inputs": f'{s.text_}"',
-                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
-                }
-            )
-            return (
-                '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
-                {},
-            )
-
-        if dtype in [int, "int"]:
-            return (
-                self.retry_for_expected(
-                    s.text_,
-                    self.adapt_params(max_tokens, stop, sampling_params),
-                    self._extract_int,
-                ),
-                {},
-            )
-
-        raise ValueError(f"Unknown dtype: {dtype}")
{sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/LICENSE
UNCHANGED
{sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/WHEEL
UNCHANGED
{sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/top_level.txt
UNCHANGED