sglang 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/api.py +1 -0
- sglang/backend/vertexai.py +147 -0
- sglang/lang/interpreter.py +8 -9
- sglang/lang/ir.py +21 -0
- sglang/srt/layers/context_flashattention_nopad.py +0 -1
- sglang/srt/layers/extend_attention.py +0 -1
- sglang/srt/managers/router/manager.py +2 -2
- sglang/srt/managers/router/model_rpc.py +6 -3
- sglang/srt/managers/router/model_runner.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/server_args.py +22 -4
- sglang/test/test_programs.py +4 -1
- {sglang-0.1.4.dist-info → sglang-0.1.5.dist-info}/METADATA +26 -8
- {sglang-0.1.4.dist-info → sglang-0.1.5.dist-info}/RECORD +18 -19
- sglang/backend/huggingface.py +0 -349
- sglang/backend/tgi.py +0 -190
- {sglang-0.1.4.dist-info → sglang-0.1.5.dist-info}/LICENSE +0 -0
- {sglang-0.1.4.dist-info → sglang-0.1.5.dist-info}/WHEEL +0 -0
- {sglang-0.1.4.dist-info → sglang-0.1.5.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
sglang/api.py
CHANGED
@@ -6,6 +6,7 @@ from sglang.backend.anthropic import Anthropic
 from sglang.backend.base_backend import BaseBackend
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.backend.vertexai import VertexAI
 from sglang.global_config import global_config
 from sglang.lang.ir import (
     SglExpr,
sglang/backend/vertexai.py
ADDED
@@ -0,0 +1,147 @@
+import os
+import warnings
+from typing import List, Optional, Union
+
+import numpy as np
+from sglang.backend.base_backend import BaseBackend
+from sglang.lang.chat_template import get_chat_template
+from sglang.lang.interpreter import StreamExecutor
+from sglang.lang.ir import SglSamplingParams
+
+try:
+    import vertexai
+    from vertexai.preview.generative_models import (
+        GenerationConfig,
+        GenerativeModel,
+        Image,
+    )
+except ImportError as e:
+    GenerativeModel = e
+
+
+class VertexAI(BaseBackend):
+    def __init__(self, model_name):
+        super().__init__()
+
+        if isinstance(GenerativeModel, Exception):
+            raise GenerativeModel
+
+        project_id = os.environ["GCP_PROJECT_ID"]
+        location = os.environ.get("GCP_LOCATION")
+        vertexai.init(project=project_id, location=location)
+
+        self.model_name = model_name
+        self.chat_template = get_chat_template("default")
+
+    def get_chat_template(self):
+        return self.chat_template
+
+    def generate(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            prompt = self.messages_to_vertexai_input(s.messages_)
+        else:
+            # single-turn
+            prompt = (
+                self.text_to_vertexai_input(s.text_, s.cur_images)
+                if s.cur_images
+                else s.text_
+            )
+        ret = GenerativeModel(self.model_name).generate_content(
+            prompt,
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+        )
+
+        comp = ret.text
+
+        return comp, {}
+
+    def generate_stream(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            prompt = self.messages_to_vertexai_input(s.messages_)
+        else:
+            # single-turn
+            prompt = (
+                self.text_to_vertexai_input(s.text_, s.cur_images)
+                if s.cur_images
+                else s.text_
+            )
+        generator = GenerativeModel(self.model_name).generate_content(
+            prompt,
+            stream=True,
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+        )
+        for ret in generator:
+            yield ret.text, {}
+
+    def text_to_vertexai_input(self, text, images):
+        input = []
+        # split with image token
+        text_segs = text.split(self.chat_template.image_token)
+        for image_path, image_base64_data in images:
+            text_seg = text_segs.pop(0)
+            if text_seg != "":
+                input.append(text_seg)
+            input.append(Image.from_bytes(image_base64_data))
+        text_seg = text_segs.pop(0)
+        if text_seg != "":
+            input.append(text_seg)
+        return input
+
+    def messages_to_vertexai_input(self, messages):
+        vertexai_message = []
+        # from openai message format to vertexai message format
+        for msg in messages:
+            if isinstance(msg["content"], str):
+                text = msg["content"]
+            else:
+                text = msg["content"][0]["text"]
+
+            if msg["role"] == "system":
+                warnings.warn("Warning: system prompt is not supported in VertexAI.")
+                vertexai_message.append(
+                    {
+                        "role": "user",
+                        "parts": [{"text": "System prompt: " + text}],
+                    }
+                )
+                vertexai_message.append(
+                    {
+                        "role": "model",
+                        "parts": [{"text": "Understood."}],
+                    }
+                )
+                continue
+            if msg["role"] == "user":
+                vertexai_msg = {
+                    "role": "user",
+                    "parts": [{"text": text}],
+                }
+            elif msg["role"] == "assistant":
+                vertexai_msg = {
+                    "role": "model",
+                    "parts": [{"text": text}],
+                }
+
+            # images
+            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
+                for image in msg["content"][1:]:
+                    assert image["type"] == "image_url"
+                    vertexai_msg["parts"].append(
+                        {
+                            "inline_data": {
+                                "data": image["image_url"]["url"].split(",")[1],
+                                "mime_type": "image/jpeg",
+                            }
+                        }
+                    )
+
+            vertexai_message.append(vertexai_msg)
+        return vertexai_message
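The new backend plugs into the existing `sgl.set_default_backend` flow like the other backends in this package. Below is a minimal usage sketch, not part of this release: the model name and the question program are illustrative, and `GCP_PROJECT_ID` must be set (with `GCP_LOCATION` optional), per `__init__` above.

```python
import sglang as sgl
from sglang.backend.vertexai import VertexAI

# VertexAI.__init__ reads GCP_PROJECT_ID (required) and GCP_LOCATION (optional)
# from the environment and passes them to vertexai.init().
sgl.set_default_backend(VertexAI("gemini-pro"))  # hypothetical model name

@sgl.function
def multi_turn_question(s, question_1, question_2):
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=128))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=128))

state = multi_turn_question.run(
    question_1="What is the capital of France?",
    question_2="Name one famous museum there.",
)
print(state["answer_1"])
```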
sglang/lang/interpreter.py
CHANGED
@@ -365,11 +365,10 @@ class StreamExecutor:
         for comp, meta_info in generator:
             self.text_ += comp
             self.variables[name] += comp
+            self.meta_info[name] = meta_info
             self.stream_var_event[name].set()
             self.stream_text_event.set()
 
-        self.meta_info[name] = meta_info
-
         self.variable_event[name].set()
         self.stream_var_event[name].set()
 
@@ -428,6 +427,7 @@ class StreamExecutor:
             self.messages_.append(last_msg)
             self.cur_images = []
         else:
+            # OpenAI chat API format
             self.messages_.append({"role": expr.role, "content": new_text})
 
         self.cur_role = None
@@ -582,7 +582,7 @@ class ProgramState:
         else:
             yield self.get_var(name)
 
-    async def text_async_iter(self, var_name=None):
+    async def text_async_iter(self, var_name=None, return_meta_data=False):
         loop = asyncio.get_running_loop()
 
         if self.stream_executor.stream:
@@ -606,7 +606,10 @@ class ProgramState:
                 out = str(self.stream_executor.variables[var_name][prev:])
                 prev += len(out)
                 if out:
-                    yield out
+                    if return_meta_data:
+                        yield out, self.stream_executor.meta_info[var_name]
+                    else:
+                        yield out
                 if self.stream_executor.variable_event[var_name].is_set():
                     break
             else:
@@ -632,11 +635,7 @@ class ProgramState:
         self.stream_executor.end()
 
     def __repr__(self) -> str:
-        msgs = self.messages()
-        ret = ""
-        for msg in msgs:
-            ret += msg["role"] + ":\n" + msg["content"] + "\n"
-        return ret
+        return f"ProgramState({self.text()})"
 
 
 class ProgramStateGroup:
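The `return_meta_data` flag added above lets async consumers receive the per-variable meta info that the stream executor now records during streaming. A rough consumption sketch follows (the program is hypothetical and assumes a backend that streams and populates meta info, such as the SRT runtime endpoint):

```python
import asyncio
import sglang as sgl

@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n")

async def main():
    state = text_qa.run(question="What is the capital of France?", stream=True)
    # With return_meta_data=True the iterator yields (chunk, meta_info) pairs
    # instead of bare text chunks, per the change to text_async_iter above.
    async for chunk, meta in state.text_async_iter("answer", return_meta_data=True):
        print(chunk, end="", flush=True)

asyncio.run(main())
```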
sglang/lang/ir.py
CHANGED
@@ -2,6 +2,7 @@
 
 import dataclasses
 import inspect
+import warnings
 from typing import List, Optional, Union
 
 from sglang.global_config import global_config
@@ -40,6 +41,8 @@ class SglSamplingParams:
 
     def to_openai_kwargs(self):
         # OpenAI does not support top_k, so we drop it here
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the OpenAI backend.")
         return {
             "max_tokens": self.max_new_tokens,
             "stop": self.stop or None,
@@ -49,8 +52,26 @@ class SglSamplingParams:
             "presence_penalty": self.presence_penalty,
         }
 
+    def to_vertexai_kwargs(self):
+        if self.regex is not None:
+            warnings.warn(
+                "Regular expression is not supported in the VertexAI backend."
+            )
+        return {
+            "candidate_count": 1,
+            "max_output_tokens": self.max_new_tokens,
+            "stop_sequences": self.stop,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k if self.top_k > 0 else None,
+        }
+
     def to_anthropic_kwargs(self):
         # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
+        if self.regex is not None:
+            warnings.warn(
+                "Regular expression is not supported in the Anthropic backend."
+            )
         return {
             "max_tokens_to_sample": self.max_new_tokens,
             "stop_sequences": self.stop,
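For a quick sense of the new mapping, here is a small illustration (the parameter values are arbitrary examples; the commented output only shows the expected shape implied by the method above):

```python
from sglang.lang.ir import SglSamplingParams

params = SglSamplingParams(temperature=0.7, top_p=0.9, top_k=40, max_new_tokens=128)
print(params.to_vertexai_kwargs())
# Expected shape:
# {"candidate_count": 1, "max_output_tokens": 128, "stop_sequences": ...,
#  "temperature": 0.7, "top_p": 0.9, "top_k": 40}
```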
sglang/srt/managers/router/manager.py
CHANGED
@@ -28,7 +28,7 @@ class RouterManager:
         self.model_client = model_client
         self.recv_reqs = []
 
-        # Init
+        # Init some configs
         self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
 
     async def loop_for_forward(self):
@@ -46,7 +46,7 @@ class RouterManager:
             if has_finished:
                 await asyncio.sleep(self.extend_dependency_time)
 
-            await asyncio.sleep(0.
+            await asyncio.sleep(0.0006)
 
     async def loop_for_recv_requests(self):
         while True:
sglang/srt/managers/router/model_rpc.py
CHANGED
@@ -2,10 +2,10 @@ import asyncio
 import logging
 import multiprocessing
 import time
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum, auto
 from typing import Dict, List, Optional, Tuple, Union
-import warnings
 
 import numpy as np
 import rpyc
@@ -45,6 +45,7 @@ class ModelRpcServer(rpyc.Service):
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
+        self.schedule_conservativeness = server_args.schedule_conservativeness
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
@@ -108,7 +109,7 @@ class ModelRpcServer(rpyc.Service):
         self.running_batch: Batch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
-        self.stream_interval =
+        self.stream_interval = server_args.stream_interval
 
         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(self.tokenizer)
@@ -248,7 +249,9 @@ class ModelRpcServer(rpyc.Service):
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-        new_ratio =
+        new_ratio = (
+            self.scheduler.new_token_estimation_ratio() * self.schedule_conservativeness
+        )
         if self.running_batch:
             available_size -= sum(
                 [
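In effect, the scheduler's estimated ratio of decode tokens per admitted request is scaled by `schedule_conservativeness` before the capacity check against `available_size`. A toy sketch of such an admission test is shown below; it is illustrative only and not the actual scheduler code.

```python
def can_admit(prefill_tokens, max_new_tokens, available_size,
              base_new_token_ratio, schedule_conservativeness=1.0):
    # A larger schedule_conservativeness inflates the reserved decode budget,
    # so fewer requests are admitted and running out of KV-cache memory
    # during decoding becomes less likely.
    new_ratio = base_new_token_ratio * schedule_conservativeness
    estimated_usage = prefill_tokens + max_new_tokens * new_ratio
    return estimated_usage <= available_size
```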
sglang/srt/models/mixtral.py
CHANGED
@@ -355,7 +355,7 @@ class MixtralForCausalLM(nn.Module):
         ):
             if "rotary_emb.inv_freq" in name:
                 continue
-            for
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
sglang/srt/server_args.py
CHANGED
@@ -16,7 +16,9 @@ class ServerArgs:
     tp_size: int = 1
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
+    schedule_conservativeness: float = 1.0
     random_seed: int = 42
+    stream_interval: int = 2
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
@@ -25,10 +27,14 @@ class ServerArgs:
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
         if self.mem_fraction_static is None:
-            if self.tp_size
-                self.mem_fraction_static = 0.
+            if self.tp_size >= 8:
+                self.mem_fraction_static = 0.80
+            elif self.tp_size >= 4:
+                self.mem_fraction_static = 0.82
+            elif self.tp_size >= 2:
+                self.mem_fraction_static = 0.85
             else:
-                self.mem_fraction_static = 0.
+                self.mem_fraction_static = 0.90
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -80,7 +86,7 @@ class ServerArgs:
             "--mem-fraction-static",
             type=float,
             default=ServerArgs.mem_fraction_static,
-            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool)",
+            help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
         )
         parser.add_argument(
             "--tp-size",
@@ -102,12 +108,24 @@ class ServerArgs:
             default=ServerArgs.schedule_heuristic,
             help="Schudule mode: [lpm, weight, random, fcfs]",
         )
+        parser.add_argument(
+            "--schedule-conservativeness",
+            type=float,
+            default=ServerArgs.schedule_conservativeness,
+            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see out-of-memory errors.",
+        )
         parser.add_argument(
             "--random-seed",
             type=int,
             default=ServerArgs.random_seed,
             help="Random seed.",
         )
+        parser.add_argument(
+            "--stream-interval",
+            type=int,
+            default=ServerArgs.stream_interval,
+            help="The interval in terms of token length for streaming",
+        )
         parser.add_argument(
             "--log-level",
             type=str,
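Both new options are exposed on the server launch command. A hypothetical invocation combining them with flags already documented in the README below (the model path is only an example):

```
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --stream-interval 4 --schedule-conservativeness 1.3
```

A larger `--schedule-conservativeness` makes the router admit requests more cautiously, and a larger `--stream-interval` sends streamed output in bigger token chunks.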
sglang/test/test_programs.py
CHANGED
@@ -304,7 +304,10 @@ def test_image_qa():
         temperature=0,
         max_new_tokens=64,
     )
-    assert
+    assert (
+        "taxi" in state.messages()[-1]["content"]
+        or "car" in state.messages()[-1]["content"]
+    )
 
 
 def test_stream():
{sglang-0.1.4.dist-info → sglang-0.1.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.1.
+Version: 0.1.5
 Summary: A structured generation langauge for LLMs.
 License: Apache License
                                  Version 2.0, January 2004
@@ -234,6 +234,7 @@ Requires-Dist: lark ; extra == 'srt'
 Requires-Dist: numba ; extra == 'srt'
 
 # SGLang
+| [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |
 
 SGLang is a structured generation language designed for large language models (LLMs).
 It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.
@@ -277,7 +278,7 @@ The example below shows how to use sglang to answer a mulit-turn question.
 ### Using OpenAI Models
 Set the OpenAI API Key
 ```
-export OPENAI_API_KEY=sk
+export OPENAI_API_KEY=sk-******
 ```
 
 Then, answer a multi-turn question.
@@ -335,6 +336,7 @@ for m in state.messages():
 
 ### More Examples
 
+Anthropic and VertexAI (Gemini) models are also supported.
 You can find more examples at [examples/quick_start](examples/quick_start).
 
 ## Frontend: Structured Generation Langauge (SGLang)
@@ -350,13 +352,14 @@ You can then invoke the function with `run` or `run_batch`.
 The system will manage the state, chat template, and parallelism for you.
 
 ### Control Flow
+You can use any Python code within the function body, including control flow, nested function calls, and external libraries.
+
 ```python
 @sgl.function
 def control_flow(s, question):
     s += "To answer this question: " + question + ", "
     s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "
 
-    # You can use if or nested function calls
     if s["tool"] == "calculator":
         s += "The math expression is" + sgl.gen("expression")
     elif s["tool"] == "web browser":
@@ -364,6 +367,9 @@ def control_flow(s, question):
 ```
 
 ### Parallelism
+Use `fork` to launch parallel prompts.
+Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel.
+
 ```python
 @sgl.function
 def tip_suggestion(s):
@@ -372,7 +378,7 @@ def tip_suggestion(s):
         "1. Balanced Diet. 2. Regular Exercise.\n\n"
     )
 
-    forks = s.fork(2)
+    forks = s.fork(2)
     for i, f in enumerate(forks):
         f += f"Now, expand tip {i+1} into a paragraph:\n"
         f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n")
@@ -383,6 +389,8 @@ def tip_suggestion(s):
 ```
 
 ### Multi Modality
+Use `sgl.image` to pass an image as input.
+
 ```python
 @sgl.function
 def image_qa(s, image_file, question):
@@ -391,6 +399,8 @@ def image_qa(s, image_file, question):
 ```
 
 ### Constrained Decoding
+Use `regex=` to specify a regular expression as a decoding constraint.
+
 ```python
 @sgl.function
 def regular_expression_gen(s):
@@ -403,6 +413,8 @@ def regular_expression_gen(s):
 ```
 
 ### Batching
+Use `run_batch` to run a batch of requests with continuous batching.
+
 ```python
 @sgl.function
 def text_qa(s, question):
@@ -415,10 +427,13 @@ states = text_qa.run_batch(
         {"question": "What is the capital of France?"},
         {"question": "What is the capital of Japan?"},
     ],
+    progress_bar=True
 )
 ```
 
 ### Streaming
+Add `stream=True` to enable streaming.
+
 ```python
 @sgl.function
 def text_qa(s, question):
@@ -427,7 +442,9 @@ def text_qa(s, question):
 
 states = text_qa.run(
     question="What is the capital of France?",
-    temperature=0.1
+    temperature=0.1,
+    stream=True
+)
 
 for out in state.text_iter():
     print(out, end="", flush=True)
@@ -471,6 +488,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 - Mixtral
 - LLaVA
   - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
+- AWQ quantization
 
 ## Benchmark And Performance
 
@@ -483,10 +501,10 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 Learn more [here](docs/benchmark_results.md).
 
 ## Roadmap
-- [ ] Function call
-- [ ] Quantization
+- [ ] Function call APIs
 - [ ] S-LoRA
-- [ ]
+- [ ] Support more models
+- [ ] Support more hardware backends
 
 ## Citation And Acknowledgment
 ```
{sglang-0.1.4.dist-info → sglang-0.1.5.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
-sglang/__init__.py,sha256=
-sglang/api.py,sha256=
+sglang/__init__.py,sha256=G73L_PWJ_6mF3NIE4ZAOWcb1CUbETSeRWr3wDTePrZ4,95
+sglang/api.py,sha256=SxmPP_PMYi4DfUcwz_V9UvYOwGmQdHPgpMV6jDDJq68,3928
 sglang/flush_cache.py,sha256=cCD_MTlQ5qEv__w0nOthDnVitdAfyscYjksBljwC5Mw,1835
 sglang/global_config.py,sha256=PAX7TWeFcq0HBzNUWyCONAOjqIokWqw8vT7I6sBSKTc,797
 sglang/launch_server.py,sha256=jKPZRDN5bUe8Wgz5eoDkqeePhmKa8DLD4DpXQLT5auo,294
@@ -7,15 +7,14 @@ sglang/utils.py,sha256=tvJs95QGZ_PcnTjvm-CDGQ8dJe84qUUOfG7BeF79nsA,5670
 sglang/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sglang/backend/anthropic.py,sha256=y5TN9EDrJtOH4JEUxpXu-endloeYBy7xMUr3r7Ah3MA,1462
 sglang/backend/base_backend.py,sha256=pPalZfoezxnUBs752j7lm0uMwa8tZuCWd-ijSdStMO8,1745
-sglang/backend/huggingface.py,sha256=roQlt8y41PQbmnAY47CXiR0KJaxhtljH6j8RhbsR4f0,10533
 sglang/backend/openai.py,sha256=umTWzC2p4PypDaXHe6Kc8By5IM_Doi0Ob97vK_fFWDc,7367
 sglang/backend/runtime_endpoint.py,sha256=rIhwtKJaLLCJAc6q6kqxEVC8xO_NNjmJs7BnxlOydLM,5860
-sglang/backend/
+sglang/backend/vertexai.py,sha256=BLfWf_tEgoHY9srCufJM5PLe3tql2j0G6ia7cPykxCM,4713
 sglang/lang/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 sglang/lang/chat_template.py,sha256=1x4724K2oxu7VID40-5Megk7SbZb97PQCbRjLpoescU,5599
 sglang/lang/compiler.py,sha256=wNn_UqV6Sxl22mv-PpzFUtRgiFFV-Y4OYpO4LshEoRM,7527
-sglang/lang/interpreter.py,sha256=
-sglang/lang/ir.py,sha256=
+sglang/lang/interpreter.py,sha256=0WTJxCB57WDBr_E6kW39wByhcG2nRFjEMTzOjAaNhrY,22453
+sglang/lang/ir.py,sha256=uUnBRyaM-8suVOEb2qf4EAt_VN2pWbXV6V88jLk6wsI,13160
 sglang/lang/tracer.py,sha256=zH9DENdJBPEvWkThgwqvHOW7aC1EPC8xha_WpEj-SRs,8243
 sglang/srt/backend_config.py,sha256=7MdHjNsZeAKB9IWWxyrvyOjJJAdI5tl9hWl-MV7yHrI,226
 sglang/srt/hf_transformers_utils.py,sha256=soRyYLoCn7GxgxvonufGFkdFBA3eH5i3Izk_wi7p1l0,5285
@@ -23,14 +22,14 @@ sglang/srt/memory_pool.py,sha256=cN3Lrs9fn0DFmt67_IN4g06mPzKUxpbAJGUw4O33xbo,360
 sglang/srt/model_config.py,sha256=R7YaR8H8AmCJl_1XcSP0zII_5ebZNl0wMXNVANGWd2c,997
 sglang/srt/sampling_params.py,sha256=Sd9l_uIIuS_mhbzljKwTGDO9ESMviNOYGxOifc71RrY,2895
 sglang/srt/server.py,sha256=XxTS1K4N5y-ZknLBQefxk1UxC50l6DABVqJOrJ-NG74,6388
-sglang/srt/server_args.py,sha256=
+sglang/srt/server_args.py,sha256=ojox8nu2tgPEy_JlKKEvRenby4HKkmWk-1MpHy3PmnI,5771
 sglang/srt/utils.py,sha256=-2F99bqYT99x1jScMjciJxgQec6CaH6PcCHSmrKHhhY,5692
 sglang/srt/constrained/fsm.py,sha256=H4kXSsV4IX2ow5TMmnmd-8ho4qqJ5mpVZ4MOH5FUtnY,12900
 sglang/srt/constrained/fsm_cache.py,sha256=KX4bFX5hj0W66SC9pSvst1ew7etaOMTtTC75z0enRME,1087
 sglang/srt/constrained/regex.py,sha256=CcV7KBOKS2ZxGoEr6BHG5okagNIGEXYvGvhKXu5gtDA,18689
 sglang/srt/constrained/tokenizer.py,sha256=rei9yKHFETcbDPOpI7bpIYdrBFgIBhGr_U-zb3r5Beo,7951
-sglang/srt/layers/context_flashattention_nopad.py,sha256=
-sglang/srt/layers/extend_attention.py,sha256=
+sglang/srt/layers/context_flashattention_nopad.py,sha256=GkjLiTkS4px_uLcW0aDocE3_OBXtujZ-SlsN2b2U7ng,5204
+sglang/srt/layers/extend_attention.py,sha256=pWVE6ySnPiVLFON__bie73eDhmXHk4tECMK8zTiJNbI,12558
 sglang/srt/layers/get_selected_logprob.py,sha256=CpMXM9WXMSB-AVaxBB_aVl1Qx_ZtAFFnjDTm4CgNDpU,2199
 sglang/srt/layers/logits_processor.py,sha256=rwcXwdZ7-dW9zvJX3MF_EHSxMLbU7TIQ9xUIYRu-WAs,3013
 sglang/srt/layers/radix_attention.py,sha256=hmPNFg2TkN4EAVUj376N_89RRtUYRwFgUpjj5SydnRk,6170
@@ -40,18 +39,18 @@ sglang/srt/managers/io_struct.py,sha256=5jMWj6_U8yTQd5V3tpDtThnoFyF0A3ln-4Z5bSL3
 sglang/srt/managers/openai_protocol.py,sha256=Eid_734Wup4jsL1ZS2Op0vwRuzvNbF4mV2UcwFxqEvI,327
 sglang/srt/managers/tokenizer_manager.py,sha256=jVwr0lM18RFJLhDb5TWlUpQ4Q8tALT4L6GY0jmaZkLw,7861
 sglang/srt/managers/router/infer_batch.py,sha256=UfS1uVhGnM-62Xv1cfu_IoTeIUxkjkKc4W3trtGbadc,11541
-sglang/srt/managers/router/manager.py,sha256=
-sglang/srt/managers/router/model_rpc.py,sha256=
-sglang/srt/managers/router/model_runner.py,sha256=
+sglang/srt/managers/router/manager.py,sha256=AVCdYKKYcIQsIwpudkfFY4jh6M--ubLjXeYGzfi2ebw,2528
+sglang/srt/managers/router/model_rpc.py,sha256=CR3qbHvShttlC19qAZ8B8nhT6UPobeu2Dy3Z0n6WdC8,19448
+sglang/srt/managers/router/model_runner.py,sha256=IhSdpBcd54HN01HDi_PAkJztFxEGDcnktdoPZDWEx3s,16487
 sglang/srt/managers/router/radix_cache.py,sha256=ZQPm9HhQ7vD3Gl5nhuvw3ZW4ZRARcplqWed1GYUvHCg,6441
 sglang/srt/managers/router/scheduler.py,sha256=ejuIRwqqMZVXFKUionRJxy5AtNvK25YoGRO9rFY-rc8,2926
 sglang/srt/models/llama2.py,sha256=D3j-NtyM8PA74UhXM7wSPogI2HKX-JcQAWcOusrZZo0,11320
 sglang/srt/models/llava.py,sha256=COS0IC6Yo-QiwKe5emgCbtEe9HgaSu5tt6CQA7UtV38,8533
-sglang/srt/models/mixtral.py,sha256=
-sglang/test/test_programs.py,sha256=
+sglang/srt/models/mixtral.py,sha256=frd2XsNZwP0XsQtRiYhgy4PErLNLgtIsLakmNrOKBAU,13712
+sglang/test/test_programs.py,sha256=EovA2xL7fODcTbFj2wAAmYKlg1mLZ1x1BRU6nrXFRdE,11416
 sglang/test/test_utils.py,sha256=Knxg3BTA6d_7XSlprbBCdvfDr2SN5x7LhkT-tZFk5EQ,4828
-sglang-0.1.
-sglang-0.1.
-sglang-0.1.
-sglang-0.1.
-sglang-0.1.
+sglang-0.1.5.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+sglang-0.1.5.dist-info/METADATA,sha256=aepmAL6VoXRcxZBIDKvxwikCYSbvWFm_JFGTxb3Mgfw,23345
+sglang-0.1.5.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+sglang-0.1.5.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
+sglang-0.1.5.dist-info/RECORD,,
sglang/backend/huggingface.py
DELETED
@@ -1,349 +0,0 @@
-import functools
-from enum import Enum, auto
-from typing import Callable, List, Optional, Union
-
-import numpy as np
-import torch
-import transformers
-from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
-from sglang.lang.interpreter import ProgramState
-from sglang.utils import get_available_gpu_memory
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-)
-from transformers.generation.logits_process import (
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
-)
-
-
-class StopReason(Enum):
-    EOS_TOKEN = auto()
-    STOP_STR = auto()
-    LENGTH = auto()
-
-
-def load_model(
-    model_name: str,
-    device,
-    num_gpus,
-    max_gpu_memory,
-    model_kwargs=None,
-    tokenizer_kwargs=None,
-):
-    model_kwargs = model_kwargs or {}
-    tokenizer_kwargs = tokenizer_kwargs or {}
-
-    if device == "cuda":
-        model_kwargs["torch_dtype"] = torch.float16
-        if num_gpus != 1:
-            model_kwargs["device_map"] = "auto"
-            if max_gpu_memory is None:
-                model_kwargs[
-                    "device_map"
-                ] = "sequential"  # This is important for not the same VRAM sizes
-                available_gpu_memory = [
-                    get_available_gpu_memory(i, False) for i in range(num_gpus)
-                ]
-                model_kwargs["max_memory"] = {
-                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
-                    for i in range(num_gpus)
-                }
-            else:
-                model_kwargs["max_memory"] = {
-                    i: max_gpu_memory for i in range(num_gpus)
-                }
-    elif device == "cpu":
-        model_kwargs["torch_dtype"] = torch.float32
-    else:
-        raise ValueError(f"Invalid device: {device}")
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name, low_cpu_mem_usage=True, **model_kwargs
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
-
-    if num_gpus == 1:
-        model.to(device).eval()
-
-    return model, tokenizer
-
-
-def prepare_logits_processor(
-    temperature: float, repetition_penalty: float, top_p: float, top_k: int
-) -> LogitsProcessorList:
-    processor_list = LogitsProcessorList()
-    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
-    if temperature >= 1e-5 and temperature != 1.0:
-        processor_list.append(TemperatureLogitsWarper(temperature))
-    if repetition_penalty > 1.0:
-        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
-    if 1e-8 <= top_p < 1.0:
-        processor_list.append(TopPLogitsWarper(top_p))
-    if top_k > 0:
-        processor_list.append(TopKLogitsWarper(top_k))
-    return processor_list
-
-
-@functools.lru_cache
-def get_token_healing_mask(tokenizer, prompt_last_token):
-    last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
-    disallowed = torch.zeros(len(tokenizer), dtype=bool)
-    for s, t_id in tokenizer.get_vocab().items():
-        if not s.startswith(last_str):
-            disallowed[t_id] = 1
-    return disallowed
-
-
-@functools.lru_cache
-def get_int_token_mask(tokenizer):
-    disallowed = torch.zeros(len(tokenizer), dtype=bool)
-    for s, t_id in tokenizer.get_vocab().items():
-        s = s.replace("▁", "").strip()
-        if not (s.isdigit() or len(s) == 0 or s == ","):
-            disallowed[t_id] = 1
-    disallowed[tokenizer.eos_token_id] = 0
-    return disallowed
-
-
-@torch.inference_mode()
-def generate_stream(
-    model,
-    tokenizer,
-    prompt,
-    max_new_tokens,
-    stop: List[str],
-    temperature,
-    top_p,
-    token_healing,
-    logit_mask=None,
-):
-    logits_processor = prepare_logits_processor(
-        temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
-    )
-    device = model.device
-    input_ids = tokenizer.encode(prompt)
-    output_ids = list(input_ids)
-    prompt_len = len(prompt)
-
-    # Resolve stop
-    stop_token_ids = [tokenizer.eos_token_id]
-
-    # Token healing
-    token_healing = token_healing and len(input_ids) > 0
-    if token_healing:
-        token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
-        del output_ids[-1]
-
-    # Generate
-    past_key_values = None
-    stop_reason = None
-    for i in range(max_new_tokens):
-        # Forward
-        if i == 0:  # prefill
-            out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
-        else:  # decoding
-            out = model(
-                input_ids=torch.as_tensor([[token]], device=device),
-                use_cache=True,
-                past_key_values=past_key_values,
-            )
-        logits = out.logits
-        past_key_values = out.past_key_values
-
-        # Logit mask
-        if token_healing and i == 0:
-            logits[0, -1, token_healing_mask] = -1e4
-        if logit_mask is not None:
-            logits[0, -1, logit_mask] = -1e4
-
-        # Sample next token
-        last_token_logits = logits_processor(None, logits[:, -1, :])[0]
-        if temperature < 1e-5 or top_p < 1e-8:  # greedy
-            token = int(torch.argmax(last_token_logits))
-        else:
-            probs = torch.softmax(last_token_logits, dim=-1)
-            token = int(torch.multinomial(probs, num_samples=1))
-        output_ids.append(token)
-
-        # Stop condition
-        if token in stop_token_ids:
-            stop_reason = StopReason.EOS_TOKEN
-            break
-
-        output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
-        for stop_str in stop:
-            pos = output_str[prompt_len:].find(stop_str)
-            if pos != -1:
-                stop_reason = StopReason.STOP_STR
-                output_str = output_str[: prompt_len + pos]
-                break
-
-        if stop_reason:
-            break
-
-    return output_str[prompt_len:]
-
-
-class HuggingFaceTransformers(BaseBackend):
-    def __init__(
-        self,
-        model_name,
-        device="cuda",
-        num_gpus=1,
-        max_gpu_memory=None,
-        model_kwargs=None,
-        tokenizer_kwargs=None,
-    ):
-        self.model_name = model_name
-        self.device = device
-
-        self.model, self.tokenizer = load_model(
-            model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
-        )
-
-        self.chat_template = get_chat_template_by_model_path(model_name)
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    def cache_prefix(self, prefix_str: str):
-        pass
-
-    def uncache_prefix(self, rid: str):
-        pass
-
-    def end_request(self, rid: str):
-        pass
-
-    def begin_program(self, s: ProgramState):
-        pass
-
-    def end_program(self, s: ProgramState):
-        pass
-
-    def fill(self, s: ProgramState, text: str):
-        return False
-
-    def generate_internal(
-        self,
-        prompt: str,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        temperature: float,
-        top_p: float,
-        dtype: Optional[str] = None,
-    ):
-        if dtype is None:
-            comp = generate_stream(
-                self.model,
-                self.tokenizer,
-                prompt,
-                max_new_tokens=max_tokens,
-                stop=stop,
-                temperature=temperature,
-                top_p=top_p,
-                token_healing=True,
-            )
-        elif dtype in [str, "str", "string"]:
-            comp = generate_stream(
-                self.model,
-                self.tokenizer,
-                prompt + '"',
-                max_new_tokens=max_tokens,
-                stop=['"'],
-                temperature=temperature,
-                top_p=top_p,
-                token_healing=False,
-            )
-            comp = '"' + comp + '"'
-        elif dtype in [int, "int"]:
-            logit_mask = get_int_token_mask(self.tokenizer)
-            comp = generate_stream(
-                self.model,
-                self.tokenizer,
-                prompt,
-                max_new_tokens=max_tokens,
-                stop=stop + [" ", ","],
-                temperature=temperature,
-                top_p=top_p,
-                token_healing=False,
-                logit_mask=logit_mask,
-            )
-        return comp
-
-    def generate(
-        self,
-        s: ProgramState,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        temperature: float,
-        top_p: float,
-        dtype: Optional[str] = None,
-    ):
-        prompt = s.text
-        comp = self.generate_internal(
-            prompt, max_tokens, stop, temperature, top_p, dtype
-        )
-        return comp
-
-    def parallel_generate(
-        self,
-        s: ProgramState,
-        prefixes: List[str],
-        join_func: Callable,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        temperature: float,
-        top_p: float,
-        dtype: Optional[str] = None,
-    ):
-        prompt = s.text
-        parallel_prompts = [prompt + prefix for prefix in prefixes]
-
-        comps = []
-        for i in range(len(parallel_prompts)):
-            comps.append(
-                self.generate_internal(
-                    parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
-                )
-            )
-
-        joined = join_func([p + c for p, c in zip(prefixes, comps)])
-        return joined, comps
-
-    @torch.inference_mode()
-    def select(
-        self, s: ProgramState, choices: List[str], temperature: float, top_p: float
-    ):
-        loss_fct = torch.nn.CrossEntropyLoss()
-        prompt = s.text
-
-        prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
-        prompt_choices = [prompt + choice for choice in choices]
-
-        scores = []
-        for i in range(len(choices)):
-            choice_ids = self.tokenizer.encode(
-                prompt_choices[i], return_tensors="pt"
-            ).to(self.model.device)
-            logits = self.model(choice_ids).logits
-
-            # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()
-
-            logprobs = torch.log(torch.softmax(logits, dim=-1))
-            idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
-            idx2 = choice_ids[0, 1:]
-            selected_logprobs = logprobs[0, idx1, idx2]
-            score = selected_logprobs.mean().item()
-            scores.append(score)
-
-        decision = choices[np.argmax(scores)]
-        return decision, scores
sglang/backend/tgi.py
DELETED
@@ -1,190 +0,0 @@
-import re
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-from itertools import repeat
-from typing import List, Optional, Union
-
-from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
-from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglSamplingParams
-from sglang.utils import http_request
-
-
-class TGI(BaseBackend):
-    def __init__(self, base_url):
-        super().__init__()
-
-        self.base_url = base_url
-
-        res = http_request(self.base_url + "/info")
-        assert res.status_code == 200
-        self.model_info = res.json()
-        self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_id"]
-        )
-
-    def get_model_name(self):
-        return self.model_info["model_id"]
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    @staticmethod
-    def adapt_params(max_tokens, stop, sampling_params, **override_params):
-        temperature = sampling_params.temperature
-        do_sample = True
-        if temperature == 0:
-            do_sample = False
-            temperature = None
-
-        if stop is None:
-            stop = []
-        elif isinstance(stop, str):
-            stop = [stop]
-
-        top_p = sampling_params.top_p
-        if top_p == 0:
-            top_p = 0.001
-        if top_p == 1:
-            top_p = 0.999
-
-        top_k = sampling_params.top_k
-        if top_k == -1:
-            top_k = None
-
-        params = {
-            "decoder_input_details": False,
-            "details": False,
-            "do_sample": do_sample,
-            "max_new_tokens": max_tokens,
-            "stop": stop,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "return_full_text": False,
-        }
-        params.update(override_params)
-        return params
-
-    @staticmethod
-    def _extract_int(text):
-        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
-        for word in words:
-            try:
-                int(word)
-                return word
-            except ValueError:
-                continue
-        raise ValueError
-
-    @staticmethod
-    def _extract_choice(choices, text):
-        # FIXME: Current only support the case where the choices are single words.
-        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
-        for word in words:
-            if word in choices:
-                return word
-        raise ValueError
-
-    @staticmethod
-    def _truncate_to_stop(text, stop):
-        # The stop sequence may not be a single token. In this case TGI will generate
-        # too many tokens so we need to truncate the output.
-        if stop:
-            stop = [stop] if isinstance(stop, str) else stop
-            for stop_seq in stop:
-                pos = text.find(stop_seq)
-                if pos != -1:
-                    return text[:pos]
-        return text
-
-    def _make_request(self, params):
-        res = http_request(self.base_url + "/generate", json=params)
-        if res.status_code != 200:
-            raise ValueError(f"Error from TGI backend: {res.text}")
-        return res.json()
-
-    def retry_for_expected(self, prompt, params, extract_fn, retry=5):
-        # TGI does not support logis_bias (yet), so we have to use an inefficient hack.
-        failed = []
-        while retry > 0:
-            res_json = self._make_request(
-                {
-                    "inputs": prompt,
-                    "parameters": params,
-                }
-            )
-            text = res_json["generated_text"]
-            try:
-                return extract_fn(text)
-            except ValueError:
-                retry -= 1
-                failed.append(text)
-
-        msg = "=" * 20 + "\n"
-        msg += f"Prompt:\n{prompt}\n"
-        msg += "=" * 20 + "\n"
-        for i, text in enumerate(failed):
-            msg += f"====== Try {i+1}:\n{text}\n"
-
-        raise ValueError(
-            f"Model {self.model_info['model_id']} served by TGI backend does not generate"
-            "expected output. Please improve the prompt, increase the temperature, or "
-            f"use different models.\n{msg}"
-        )
-
-    def select(
-        self,
-        s: StreamExecutor,
-        choices: List[str],
-        sampling_params: SglSamplingParams,
-    ):
-        decision = self.retry_for_expected(
-            s.text_,
-            self.adapt_params(16, [], sampling_params),
-            partial(self._extract_choice, choices),
-        )
-        return decision, [1 if choice == decision else 0 for choice in choices]
-
-    def generate(
-        self,
-        s: StreamExecutor,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        sampling_params: SglSamplingParams,
-        dtype: Optional[str] = None,
-    ):
-        if dtype is None:
-            res_json = self._make_request(
-                {
-                    "inputs": s.text_,
-                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
-                }
-            )
-            return self._truncate_to_stop(res_json["generated_text"], stop), {}
-
-        if dtype in [str, "str", "string"]:
-            stop = ['"']
-            res_json = self._make_request(
-                {
-                    "inputs": f'{s.text_}"',
-                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
-                }
-            )
-            return (
-                '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
-                {},
-            )
-
-        if dtype in [int, "int"]:
-            return (
-                self.retry_for_expected(
-                    s.text_,
-                    self.adapt_params(max_tokens, stop, sampling_params),
-                    self._extract_int,
-                ),
-                {},
-            )
-
-        raise ValueError(f"Unknown dtype: {dtype}")
{sglang-0.1.4.dist-info → sglang-0.1.5.dist-info}/LICENSE
File without changes
{sglang-0.1.4.dist-info → sglang-0.1.5.dist-info}/WHEEL
File without changes
{sglang-0.1.4.dist-info → sglang-0.1.5.dist-info}/top_level.txt
File without changes