sglang 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.3"
+ __version__ = "0.1.5"

  from sglang.api import *
  from sglang.global_config import global_config
sglang/api.py CHANGED
@@ -6,6 +6,7 @@ from sglang.backend.anthropic import Anthropic
  from sglang.backend.base_backend import BaseBackend
  from sglang.backend.openai import OpenAI
  from sglang.backend.runtime_endpoint import RuntimeEndpoint
+ from sglang.backend.vertexai import VertexAI
  from sglang.global_config import global_config
  from sglang.lang.ir import (
      SglExpr,
sglang/backend/vertexai.py ADDED
@@ -0,0 +1,147 @@
+ import os
+ import warnings
+ from typing import List, Optional, Union
+
+ import numpy as np
+ from sglang.backend.base_backend import BaseBackend
+ from sglang.lang.chat_template import get_chat_template
+ from sglang.lang.interpreter import StreamExecutor
+ from sglang.lang.ir import SglSamplingParams
+
+ try:
+     import vertexai
+     from vertexai.preview.generative_models import (
+         GenerationConfig,
+         GenerativeModel,
+         Image,
+     )
+ except ImportError as e:
+     GenerativeModel = e
+
+
+ class VertexAI(BaseBackend):
+     def __init__(self, model_name):
+         super().__init__()
+
+         if isinstance(GenerativeModel, Exception):
+             raise GenerativeModel
+
+         project_id = os.environ["GCP_PROJECT_ID"]
+         location = os.environ.get("GCP_LOCATION")
+         vertexai.init(project=project_id, location=location)
+
+         self.model_name = model_name
+         self.chat_template = get_chat_template("default")
+
+     def get_chat_template(self):
+         return self.chat_template
+
+     def generate(
+         self,
+         s: StreamExecutor,
+         sampling_params: SglSamplingParams,
+     ):
+         if s.messages_:
+             prompt = self.messages_to_vertexai_input(s.messages_)
+         else:
+             # single-turn
+             prompt = (
+                 self.text_to_vertexai_input(s.text_, s.cur_images)
+                 if s.cur_images
+                 else s.text_
+             )
+         ret = GenerativeModel(self.model_name).generate_content(
+             prompt,
+             generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+         )
+
+         comp = ret.text
+
+         return comp, {}
+
+     def generate_stream(
+         self,
+         s: StreamExecutor,
+         sampling_params: SglSamplingParams,
+     ):
+         if s.messages_:
+             prompt = self.messages_to_vertexai_input(s.messages_)
+         else:
+             # single-turn
+             prompt = (
+                 self.text_to_vertexai_input(s.text_, s.cur_images)
+                 if s.cur_images
+                 else s.text_
+             )
+         generator = GenerativeModel(self.model_name).generate_content(
+             prompt,
+             stream=True,
+             generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+         )
+         for ret in generator:
+             yield ret.text, {}
+
+     def text_to_vertexai_input(self, text, images):
+         input = []
+         # split with image token
+         text_segs = text.split(self.chat_template.image_token)
+         for image_path, image_base64_data in images:
+             text_seg = text_segs.pop(0)
+             if text_seg != "":
+                 input.append(text_seg)
+             input.append(Image.from_bytes(image_base64_data))
+         text_seg = text_segs.pop(0)
+         if text_seg != "":
+             input.append(text_seg)
+         return input
+
+     def messages_to_vertexai_input(self, messages):
+         vertexai_message = []
+         # from openai message format to vertexai message format
+         for msg in messages:
+             if isinstance(msg["content"], str):
+                 text = msg["content"]
+             else:
+                 text = msg["content"][0]["text"]
+
+             if msg["role"] == "system":
+                 warnings.warn("Warning: system prompt is not supported in VertexAI.")
+                 vertexai_message.append(
+                     {
+                         "role": "user",
+                         "parts": [{"text": "System prompt: " + text}],
+                     }
+                 )
+                 vertexai_message.append(
+                     {
+                         "role": "model",
+                         "parts": [{"text": "Understood."}],
+                     }
+                 )
+                 continue
+             if msg["role"] == "user":
+                 vertexai_msg = {
+                     "role": "user",
+                     "parts": [{"text": text}],
+                 }
+             elif msg["role"] == "assistant":
+                 vertexai_msg = {
+                     "role": "model",
+                     "parts": [{"text": text}],
+                 }
+
+             # images
+             if isinstance(msg["content"], list) and len(msg["content"]) > 1:
+                 for image in msg["content"][1:]:
+                     assert image["type"] == "image_url"
+                     vertexai_msg["parts"].append(
+                         {
+                             "inline_data": {
+                                 "data": image["image_url"]["url"].split(",")[1],
+                                 "mime_type": "image/jpeg",
+                             }
+                         }
+                     )
+
+             vertexai_message.append(vertexai_msg)
+         return vertexai_message
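For context, here is a minimal sketch of how the new backend might be driven from the sglang frontend. It is not part of the diff; the project id, location, and "gemini-pro" model name are placeholders, and it assumes `VertexAI` is re-exported by `sglang` via the `sglang/api.py` import shown above.

```python
# Hypothetical usage sketch for the new VertexAI backend (not part of this diff).
import os

import sglang as sgl

os.environ.setdefault("GCP_PROJECT_ID", "my-project")  # read in VertexAI.__init__
os.environ.setdefault("GCP_LOCATION", "us-central1")   # optional, per the code above


@sgl.function
def multi_turn_question(s, question_1, question_2):
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=256))


sgl.set_default_backend(sgl.VertexAI("gemini-pro"))  # assumed export and model name
state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions there.",
)
print(state["answer_1"])
```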
sglang/lang/interpreter.py CHANGED
@@ -365,11 +365,10 @@ class StreamExecutor:
              for comp, meta_info in generator:
                  self.text_ += comp
                  self.variables[name] += comp
+                 self.meta_info[name] = meta_info
                  self.stream_var_event[name].set()
                  self.stream_text_event.set()

-             self.meta_info[name] = meta_info
-
              self.variable_event[name].set()
              self.stream_var_event[name].set()

@@ -428,6 +427,7 @@ class StreamExecutor:
              self.messages_.append(last_msg)
              self.cur_images = []
          else:
+             # OpenAI chat API format
              self.messages_.append({"role": expr.role, "content": new_text})

          self.cur_role = None
@@ -582,7 +582,7 @@ class ProgramState:
              else:
                  yield self.get_var(name)

-     async def text_async_iter(self, var_name=None):
+     async def text_async_iter(self, var_name=None, return_meta_data=False):
          loop = asyncio.get_running_loop()

          if self.stream_executor.stream:
@@ -606,7 +606,10 @@ class ProgramState:
                  out = str(self.stream_executor.variables[var_name][prev:])
                  prev += len(out)
                  if out:
-                     yield out
+                     if return_meta_data:
+                         yield out, self.stream_executor.meta_info[var_name]
+                     else:
+                         yield out
                  if self.stream_executor.variable_event[var_name].is_set():
                      break
          else:
@@ -632,11 +635,7 @@ class ProgramState:
          self.stream_executor.end()

      def __repr__(self) -> str:
-         msgs = self.messages()
-         ret = ""
-         for msg in msgs:
-             ret += msg["role"] + ":\n" + msg["content"] + "\n"
-         return ret
+         return f"ProgramState({self.text()})"


  class ProgramStateGroup:
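The `return_meta_data` flag added above makes the async iterator yield `(text_chunk, meta_info)` tuples instead of bare strings. A hedged sketch of consuming it, assuming `state` is a `ProgramState` from a run started with `stream=True` that generates a variable named `"answer"`:

```python
# Illustrative consumer of the new return_meta_data flag; `state` and the
# variable name "answer" are assumptions, not part of the diff.
import asyncio


async def print_stream(state):
    async for chunk, meta in state.text_async_iter("answer", return_meta_data=True):
        print(chunk, end="", flush=True)
        # `meta` is the per-variable meta_info dict the backend attached (may be empty)


# asyncio.run(print_stream(state))  # run inside an event loop with a real streaming state
```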
sglang/lang/ir.py CHANGED
@@ -2,6 +2,7 @@

  import dataclasses
  import inspect
+ import warnings
  from typing import List, Optional, Union

  from sglang.global_config import global_config
@@ -40,6 +41,8 @@ class SglSamplingParams:

      def to_openai_kwargs(self):
          # OpenAI does not support top_k, so we drop it here
+         if self.regex is not None:
+             warnings.warn("Regular expression is not supported in the OpenAI backend.")
          return {
              "max_tokens": self.max_new_tokens,
              "stop": self.stop or None,
@@ -49,8 +52,26 @@ class SglSamplingParams:
              "presence_penalty": self.presence_penalty,
          }

+     def to_vertexai_kwargs(self):
+         if self.regex is not None:
+             warnings.warn(
+                 "Regular expression is not supported in the VertexAI backend."
+             )
+         return {
+             "candidate_count": 1,
+             "max_output_tokens": self.max_new_tokens,
+             "stop_sequences": self.stop,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+             "top_k": self.top_k if self.top_k > 0 else None,
+         }
+
      def to_anthropic_kwargs(self):
          # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
+         if self.regex is not None:
+             warnings.warn(
+                 "Regular expression is not supported in the Anthropic backend."
+             )
          return {
              "max_tokens_to_sample": self.max_new_tokens,
              "stop_sequences": self.stop,
sglang/srt/layers/context_flashattention_nopad.py CHANGED
@@ -5,6 +5,8 @@ import triton
  import triton.language as tl
  from sglang.srt.utils import wrap_kernel_launcher

+ CUDA_CAPABILITY = torch.cuda.get_device_capability()
+

  @triton.jit
  def _fwd_kernel(
@@ -120,7 +122,11 @@ cached_kernel = None


  def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
-     BLOCK = 128
+     if CUDA_CAPABILITY[0] >= 8:
+         BLOCK = 128
+     else:
+         BLOCK = 64
+
      Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
      assert Lq == Lk and Lk == Lv
      assert Lk in {16, 32, 64, 128}
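The change above picks the Triton block size from the GPU's compute capability: 128 on Ampere or newer, 64 on older parts such as T4 or V100. A standalone sketch of the same check, not taken from the diff:

```python
# Standalone illustration of the capability-based block-size selection.
import torch


def pick_block_size() -> int:
    # (8, 0) is A100 / Ampere; T4 reports (7, 5) and V100 reports (7, 0)
    major, _minor = torch.cuda.get_device_capability()
    return 128 if major >= 8 else 64


if torch.cuda.is_available():
    BLOCK = pick_block_size()
else:
    BLOCK = 64  # arbitrary fallback just for this sketch
```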
sglang/srt/layers/extend_attention.py CHANGED
@@ -2,6 +2,9 @@ import torch
  import triton
  import triton.language as tl
  from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
+ from sglang.srt.utils import wrap_kernel_launcher
+
+ CUDA_CAPABILITY = torch.cuda.get_device_capability()


  @triton.jit
@@ -153,6 +156,9 @@ def _fwd_kernel(
      tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])


+ cached_kernel = None
+
+
  def extend_attention_fwd(
      q_extend,
      k_extend,
@@ -175,7 +181,11 @@ def extend_attention_fwd(

      k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
      """
-     BLOCK_M, BLOCK_N = 128, 128
+     if CUDA_CAPABILITY[0] >= 8:
+         BLOCK_M, BLOCK_N = 128, 128
+     else:
+         BLOCK_M, BLOCK_N = 64, 64
+
      Lq, Lk, Lv, Lo = (
          q_extend.shape[-1],
          k_extend.shape[-1],
@@ -193,6 +203,40 @@ def extend_attention_fwd(
      num_warps = 4 if Lk <= 64 else 8
      num_stages = 1

+     global cached_kernel
+     if cached_kernel:
+         cached_kernel(
+             grid,
+             num_warps,
+             q_extend,
+             k_extend,
+             v_extend,
+             o_extend,
+             k_buffer,
+             v_buffer,
+             req_to_tokens,
+             b_req_idx,
+             b_seq_len,
+             b_start_loc_extend,
+             b_seq_len_extend,
+             sm_scale,
+             kv_group_num,
+             q_extend.stride(0),
+             q_extend.stride(1),
+             k_extend.stride(0),
+             k_extend.stride(1),
+             v_extend.stride(0),
+             v_extend.stride(1),
+             o_extend.stride(0),
+             o_extend.stride(1),
+             k_buffer.stride(0),
+             k_buffer.stride(1),
+             v_buffer.stride(0),
+             v_buffer.stride(1),
+             req_to_tokens.stride(0),
+         )
+         return
+
      _fwd_kernel[grid](
          q_extend,
          k_extend,
@@ -226,6 +270,7 @@ def extend_attention_fwd(
          num_warps=num_warps,
          num_stages=num_stages,
      )
+     cached_kernel = wrap_kernel_launcher(_fwd_kernel)


  def redundant_attention(
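The `cached_kernel` / `wrap_kernel_launcher` change above caches a pre-bound launcher after the first Triton dispatch so later calls skip the launch setup overhead. Below is a self-contained analogue of that pattern with plain-Python stand-ins; `wrap_kernel_launcher` itself is sglang's helper and is not reproduced here, and all names are illustrative.

```python
# Self-contained analogue of the cached-launcher pattern (illustrative only).
from typing import Callable, Optional

_cached_launcher: Optional[Callable[..., None]] = None


def _build_launcher() -> Callable[..., None]:
    # Stand-in for the one-time cost of compiling/binding a Triton kernel.
    def launch(grid, num_warps, *tensors) -> None:
        pass  # the real launcher would enqueue the already-compiled kernel here

    return launch


def attention_fwd(grid=(1,), num_warps=4, *tensors) -> None:
    global _cached_launcher
    if _cached_launcher:
        _cached_launcher(grid, num_warps, *tensors)  # fast path: reuse cached launcher
        return
    launcher = _build_launcher()  # slow path: first call pays the setup cost
    launcher(grid, num_warps, *tensors)
    _cached_launcher = launcher


attention_fwd()
attention_fwd()  # second call takes the cached fast path
```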
sglang/srt/managers/router/manager.py CHANGED
@@ -28,7 +28,7 @@ class RouterManager:
          self.model_client = model_client
          self.recv_reqs = []

-         # Init Some Configs
+         # Init some configs
          self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time

      async def loop_for_forward(self):
@@ -46,7 +46,7 @@ class RouterManager:
              if has_finished:
                  await asyncio.sleep(self.extend_dependency_time)

-             await asyncio.sleep(0.001)
+             await asyncio.sleep(0.0006)

      async def loop_for_recv_requests(self):
          while True:
sglang/srt/managers/router/model_rpc.py CHANGED
@@ -2,6 +2,7 @@ import asyncio
  import logging
  import multiprocessing
  import time
+ import warnings
  from concurrent.futures import ThreadPoolExecutor
  from enum import Enum, auto
  from typing import Dict, List, Optional, Tuple, Union
@@ -44,6 +45,7 @@ class ModelRpcServer(rpyc.Service):
          self.tp_rank = tp_rank
          self.tp_size = server_args.tp_size
          self.schedule_heuristic = server_args.schedule_heuristic
+         self.schedule_conservativeness = server_args.schedule_conservativeness

          # Init model and tokenizer
          self.model_config = ModelConfig(
@@ -107,7 +109,7 @@ class ModelRpcServer(rpyc.Service):
          self.running_batch: Batch = None
          self.out_pyobjs = []
          self.decode_forward_ct = 0
-         self.stream_interval = 2
+         self.stream_interval = server_args.stream_interval

          # Init the FSM cache for constrained generation
          self.regex_fsm_cache = FSMCache(self.tokenizer)
@@ -164,7 +166,7 @@ class ModelRpcServer(rpyc.Service):
              + self.tree_cache.evictable_size()
          )
          if available_size != self.max_total_num_token:
-             logger.warning(
+             warnings.warn(
                  "Warning: "
                  f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
                  "KV cache pool leak detected!"
@@ -247,7 +249,9 @@ class ModelRpcServer(rpyc.Service):
          available_size = (
              self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
          )
-         new_ratio = self.scheduler.new_token_estimation_ratio()
+         new_ratio = (
+             self.scheduler.new_token_estimation_ratio() * self.schedule_conservativeness
+         )
          if self.running_batch:
              available_size -= sum(
                  [
sglang/srt/managers/router/model_runner.py CHANGED
@@ -278,7 +278,7 @@ class ModelRunner:
              load_format=self.load_format,
              revision=None,
          )
-         self.model = model
+         self.model = model.eval()

      def profile_max_num_token(self, total_gpu_memory):
          available_gpu_memory = get_available_gpu_memory(
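Switching the loaded model to `.eval()` matters for modules whose forward pass differs between training and inference, such as dropout and batch norm. A minimal, generic PyTorch illustration (not sglang code):

```python
# Minimal illustration of train() vs eval() behavior for a dropout layer.
import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

drop.train()
print(drop(x))  # some elements randomly zeroed, survivors scaled by 2.0

drop.eval()
print(drop(x))  # identity: tensor([1., 1., 1., 1.])
```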
sglang/srt/models/mixtral.py CHANGED
@@ -355,7 +355,7 @@ class MixtralForCausalLM(nn.Module):
          ):
              if "rotary_emb.inv_freq" in name:
                  continue
-             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+             for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
                      continue
                  name = name.replace(weight_name, param_name)
sglang/srt/server_args.py CHANGED
@@ -16,7 +16,9 @@ class ServerArgs:
      tp_size: int = 1
      model_mode: List[str] = ()
      schedule_heuristic: str = "lpm"
+     schedule_conservativeness: float = 1.0
      random_seed: int = 42
+     stream_interval: int = 2
      disable_log_stats: bool = False
      log_stats_interval: int = 10
      log_level: str = "info"
@@ -25,10 +27,14 @@ class ServerArgs:
          if self.tokenizer_path is None:
              self.tokenizer_path = self.model_path
          if self.mem_fraction_static is None:
-             if self.tp_size > 1:
-                 self.mem_fraction_static = 0.8
+             if self.tp_size >= 8:
+                 self.mem_fraction_static = 0.80
+             elif self.tp_size >= 4:
+                 self.mem_fraction_static = 0.82
+             elif self.tp_size >= 2:
+                 self.mem_fraction_static = 0.85
              else:
-                 self.mem_fraction_static = 0.9
+                 self.mem_fraction_static = 0.90

      @staticmethod
      def add_cli_args(parser: argparse.ArgumentParser):
@@ -80,7 +86,7 @@ class ServerArgs:
              "--mem-fraction-static",
              type=float,
              default=ServerArgs.mem_fraction_static,
-             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool)",
+             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
          )
          parser.add_argument(
              "--tp-size",
@@ -102,12 +108,24 @@ class ServerArgs:
              default=ServerArgs.schedule_heuristic,
              help="Schudule mode: [lpm, weight, random, fcfs]",
          )
+         parser.add_argument(
+             "--schedule-conservativeness",
+             type=float,
+             default=ServerArgs.schedule_conservativeness,
+             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see out-of-memory errors.",
+         )
          parser.add_argument(
              "--random-seed",
              type=int,
              default=ServerArgs.random_seed,
              help="Random seed.",
          )
+         parser.add_argument(
+             "--stream-interval",
+             type=int,
+             default=ServerArgs.stream_interval,
+             help="The interval in terms of token length for streaming",
+         )
          parser.add_argument(
              "--log-level",
              type=str,
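A hedged sketch of how the new `ServerArgs` fields resolve, assuming the dataclass can be instantiated directly in Python (the model path is a placeholder):

```python
# Illustrative only: the new scheduling/streaming knobs and the tp_size-based
# mem_fraction_static defaults added above.
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder path
    tp_size=2,
    schedule_conservativeness=1.3,  # >1.0 admits new requests more conservatively
    stream_interval=4,              # emit a streaming chunk every 4 decoded tokens
)
print(args.mem_fraction_static)  # 0.85 under the new tp_size >= 2 default
```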
sglang/srt/utils.py CHANGED
@@ -209,7 +209,7 @@ def load_image(image_file):
      elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
          image = Image.open(image_file)
      elif image_file.startswith("data:"):
-         image_file = image_url.split(",")[1]
+         image_file = image_file.split(",")[1]
          image = Image.open(BytesIO(base64.b64decode(image_file)))
      else:
          image = Image.open(BytesIO(base64.b64decode(image_file)))
sglang/test/test_programs.py CHANGED
@@ -304,7 +304,7 @@ def test_image_qa():
          temperature=0,
          max_new_tokens=64,
      )
-     assert "taxi" in state.messages()[-1]["content"]
+     assert (
+         "taxi" in state.messages()[-1]["content"]
+         or "car" in state.messages()[-1]["content"]
+     )


  def test_stream():
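The `load_image` fix above corrects the `data:` branch, which previously referenced an undefined `image_url` and would raise a `NameError`; the split now uses `image_file`. A small sketch exercising that branch, assuming Pillow is installed and that `load_image` returns the decoded image:

```python
# Illustrative only: build a tiny data: URI and feed it through the fixed branch.
import base64
import io

from PIL import Image

from sglang.srt.utils import load_image

buf = io.BytesIO()
Image.new("RGB", (2, 2), color="red").save(buf, format="PNG")
data_uri = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

img = load_image(data_uri)  # previously raised NameError on this code path
```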
{sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sglang
- Version: 0.1.3
+ Version: 0.1.5
  Summary: A structured generation langauge for LLMs.
  License: Apache License
  Version 2.0, January 2004
@@ -234,6 +234,7 @@ Requires-Dist: lark ; extra == 'srt'
  Requires-Dist: numba ; extra == 'srt'

  # SGLang
+ | [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |

  SGLang is a structured generation language designed for large language models (LLMs).
  It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.
@@ -267,10 +268,20 @@ pip install --upgrade pip
  pip install -e "python[all]"
  ```

+ ### Notes
+ - If you are using older GPUs (NVIDIA T4, V100), please use `pip install "triton>=2.2.0"` to avoid some bugs in the triton compiler
+ - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install sglang[openai]`
+
  ## Quick Start
  The example below shows how to use sglang to answer a mulit-turn question.

  ### Using OpenAI Models
+ Set the OpenAI API Key
+ ```
+ export OPENAI_API_KEY=sk-******
+ ```
+
+ Then, answer a multi-turn question.
  ```python
  from sglang import function, system, user, assistant, gen, set_default_backend, OpenAI

@@ -325,6 +336,7 @@ for m in state.messages():

  ### More Examples

+ Anthropic and VertexAI (Gemini) models are also supported.
  You can find more examples at [examples/quick_start](examples/quick_start).

  ## Frontend: Structured Generation Langauge (SGLang)
@@ -334,19 +346,20 @@ To begin with, import sglang.
  import sglang as sgl
  ```

- `sglang` provides some simple primitives such as `gen`, `select`, `fork`.
+ `sglang` provides some simple primitives such as `gen`, `select`, `fork`, `image`.
  You can implement your prompt flow in a function decorated by `sgl.function`.
  You can then invoke the function with `run` or `run_batch`.
  The system will manage the state, chat template, and parallelism for you.

  ### Control Flow
+ You can use any Python code within the function body, including control flow, nested function calls, and external libraries.
+
  ```python
  @sgl.function
  def control_flow(s, question):
      s += "To answer this question: " + question + ", "
      s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "

-     # You can use if or nested function calls
      if s["tool"] == "calculator":
          s += "The math expression is" + sgl.gen("expression")
      elif s["tool"] == "web browser":
@@ -354,6 +367,9 @@ def control_flow(s, question):
  ```

  ### Parallelism
+ Use `fork` to launch parallel prompts.
+ Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel.
+
  ```python
  @sgl.function
  def tip_suggestion(s):
@@ -362,7 +378,7 @@ def tip_suggestion(s):
          "1. Balanced Diet. 2. Regular Exercise.\n\n"
      )

-     forks = s.fork(2) # Launch parallel prompts
+     forks = s.fork(2)
      for i, f in enumerate(forks):
          f += f"Now, expand tip {i+1} into a paragraph:\n"
          f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n")
@@ -373,6 +389,8 @@ def tip_suggestion(s):
  ```

  ### Multi Modality
+ Use `sgl.image` to pass an image as input.
+
  ```python
  @sgl.function
  def image_qa(s, image_file, question):
@@ -381,11 +399,13 @@ def image_qa(s, image_file, question):
  ```

  ### Constrained Decoding
+ Use `regex=` to specify a regular expression as a decoding constraint.
+
  ```python
- @function
+ @sgl.function
  def regular_expression_gen(s):
      s += "Q: What is the IP address of the Google DNS servers?\n"
-     s += "A: " + gen(
+     s += "A: " + sgl.gen(
          "answer",
          temperature=0,
          regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?).){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
@@ -393,6 +413,8 @@ def regular_expression_gen(s):
  ```

  ### Batching
+ Use `run_batch` to run a batch of requests with continuous batching.
+
  ```python
  @sgl.function
  def text_qa(s, question):
@@ -405,10 +427,13 @@ states = text_qa.run_batch(
          {"question": "What is the capital of France?"},
          {"question": "What is the capital of Japan?"},
      ],
+     progress_bar=True
  )
  ```

  ### Streaming
+ Add `stream=True` to enable streaming.
+
  ```python
  @sgl.function
  def text_qa(s, question):
@@ -417,7 +442,9 @@ def text_qa(s, question):

  states = text_qa.run(
      question="What is the capital of France?",
-     temperature=0.1)
+     temperature=0.1,
+     stream=True
+ )

  for out in state.text_iter():
      print(out, end="", flush=True)
@@ -426,7 +453,7 @@ for out in state.text_iter():
  ## Backend: SGLang Runtime (SRT)
  The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
  However, it can also be used as a standalone API server.
- In this case, the RadixAttention can still greatly accelerate many use cases.
+ In this case, the [RadixAttention](https://arxiv.org/abs/2312.07104) can still greatly accelerate many use cases with automatic KV cache reuse.

  ### Usage
  Launch a server
@@ -450,6 +477,10 @@ curl http://localhost:30000/v1/completions \
  ```
  python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --tp 2
  ```
+ - If you see out-of-memory errors during serving, please try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`
+ ```
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
+ ```

  ### Supported Models
  - Llama
@@ -457,6 +488,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
  - Mixtral
  - LLaVA
    - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
+ - AWQ quantization

  ## Benchmark And Performance

@@ -466,13 +498,13 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
  - Mixtral-8x7B on NVIDIA A10G, FP16, Tensor Parallelism=8
  ![mixtral_8x7b](assets/mixtral_8x7b.jpg)

- Learn more [here]().
+ Learn more [here](docs/benchmark_results.md).

  ## Roadmap
- - [ ] Function call
- - [ ] Quantization
+ - [ ] Function call APIs
  - [ ] S-LoRA
- - [ ] More models
+ - [ ] Support more models
+ - [ ] Support more hardware backends

  ## Citation And Acknowledgment
  ```
{sglang-0.1.3.dist-info → sglang-0.1.5.dist-info}/RECORD RENAMED
@@ -1,5 +1,5 @@
- sglang/__init__.py,sha256=U_vIUJQoAKKm3mK9wNlAiUFO4rk5G0epSNmOO43IQrI,95
- sglang/api.py,sha256=tJuEyB28BUQfl0-dQr4vi6UMHBhUbmyu9Z3iAE5xFcU,3883
+ sglang/__init__.py,sha256=G73L_PWJ_6mF3NIE4ZAOWcb1CUbETSeRWr3wDTePrZ4,95
+ sglang/api.py,sha256=SxmPP_PMYi4DfUcwz_V9UvYOwGmQdHPgpMV6jDDJq68,3928
  sglang/flush_cache.py,sha256=cCD_MTlQ5qEv__w0nOthDnVitdAfyscYjksBljwC5Mw,1835
  sglang/global_config.py,sha256=PAX7TWeFcq0HBzNUWyCONAOjqIokWqw8vT7I6sBSKTc,797
  sglang/launch_server.py,sha256=jKPZRDN5bUe8Wgz5eoDkqeePhmKa8DLD4DpXQLT5auo,294
@@ -7,15 +7,14 @@ sglang/utils.py,sha256=tvJs95QGZ_PcnTjvm-CDGQ8dJe84qUUOfG7BeF79nsA,5670
  sglang/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sglang/backend/anthropic.py,sha256=y5TN9EDrJtOH4JEUxpXu-endloeYBy7xMUr3r7Ah3MA,1462
  sglang/backend/base_backend.py,sha256=pPalZfoezxnUBs752j7lm0uMwa8tZuCWd-ijSdStMO8,1745
- sglang/backend/huggingface.py,sha256=roQlt8y41PQbmnAY47CXiR0KJaxhtljH6j8RhbsR4f0,10533
  sglang/backend/openai.py,sha256=umTWzC2p4PypDaXHe6Kc8By5IM_Doi0Ob97vK_fFWDc,7367
  sglang/backend/runtime_endpoint.py,sha256=rIhwtKJaLLCJAc6q6kqxEVC8xO_NNjmJs7BnxlOydLM,5860
- sglang/backend/tgi.py,sha256=2wlfparGJNLN806bvPi_8jsk6ezJG1QviSZu2IBf1No,5935
+ sglang/backend/vertexai.py,sha256=BLfWf_tEgoHY9srCufJM5PLe3tql2j0G6ia7cPykxCM,4713
  sglang/lang/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sglang/lang/chat_template.py,sha256=1x4724K2oxu7VID40-5Megk7SbZb97PQCbRjLpoescU,5599
  sglang/lang/compiler.py,sha256=wNn_UqV6Sxl22mv-PpzFUtRgiFFV-Y4OYpO4LshEoRM,7527
- sglang/lang/interpreter.py,sha256=YqCqsVonZt_xwL1ZMBNXHRSyxnGUVQr736ESn1Q7NWE,22339
- sglang/lang/ir.py,sha256=9vUL68VkT3gmcDaLjTiwerM21UwlQg-95FRaIt32jSM,12380
+ sglang/lang/interpreter.py,sha256=0WTJxCB57WDBr_E6kW39wByhcG2nRFjEMTzOjAaNhrY,22453
+ sglang/lang/ir.py,sha256=uUnBRyaM-8suVOEb2qf4EAt_VN2pWbXV6V88jLk6wsI,13160
  sglang/lang/tracer.py,sha256=zH9DENdJBPEvWkThgwqvHOW7aC1EPC8xha_WpEj-SRs,8243
  sglang/srt/backend_config.py,sha256=7MdHjNsZeAKB9IWWxyrvyOjJJAdI5tl9hWl-MV7yHrI,226
  sglang/srt/hf_transformers_utils.py,sha256=soRyYLoCn7GxgxvonufGFkdFBA3eH5i3Izk_wi7p1l0,5285
@@ -23,14 +22,14 @@ sglang/srt/memory_pool.py,sha256=cN3Lrs9fn0DFmt67_IN4g06mPzKUxpbAJGUw4O33xbo,360
  sglang/srt/model_config.py,sha256=R7YaR8H8AmCJl_1XcSP0zII_5ebZNl0wMXNVANGWd2c,997
  sglang/srt/sampling_params.py,sha256=Sd9l_uIIuS_mhbzljKwTGDO9ESMviNOYGxOifc71RrY,2895
  sglang/srt/server.py,sha256=XxTS1K4N5y-ZknLBQefxk1UxC50l6DABVqJOrJ-NG74,6388
- sglang/srt/server_args.py,sha256=Fpj3To5hEgmWn9qCS-pfypOEh34x9xVmiHBoEx5Smbo,4932
- sglang/srt/utils.py,sha256=YtTLEtVnOTrjub0Ct_xjrKGtHIajiQ57FB38l6Dw3a4,5691
+ sglang/srt/server_args.py,sha256=ojox8nu2tgPEy_JlKKEvRenby4HKkmWk-1MpHy3PmnI,5771
+ sglang/srt/utils.py,sha256=-2F99bqYT99x1jScMjciJxgQec6CaH6PcCHSmrKHhhY,5692
  sglang/srt/constrained/fsm.py,sha256=H4kXSsV4IX2ow5TMmnmd-8ho4qqJ5mpVZ4MOH5FUtnY,12900
  sglang/srt/constrained/fsm_cache.py,sha256=KX4bFX5hj0W66SC9pSvst1ew7etaOMTtTC75z0enRME,1087
  sglang/srt/constrained/regex.py,sha256=CcV7KBOKS2ZxGoEr6BHG5okagNIGEXYvGvhKXu5gtDA,18689
  sglang/srt/constrained/tokenizer.py,sha256=rei9yKHFETcbDPOpI7bpIYdrBFgIBhGr_U-zb3r5Beo,7951
- sglang/srt/layers/context_flashattention_nopad.py,sha256=qQc35BVOYPoZlLbbTUWB3a43Zwd3v5ZKR_uFRoypUIU,5084
- sglang/srt/layers/extend_attention.py,sha256=X-3nrQBeUyA3_cp2vZH1dC85x-EF9rppiK95FocMnKA,11423
+ sglang/srt/layers/context_flashattention_nopad.py,sha256=GkjLiTkS4px_uLcW0aDocE3_OBXtujZ-SlsN2b2U7ng,5204
+ sglang/srt/layers/extend_attention.py,sha256=pWVE6ySnPiVLFON__bie73eDhmXHk4tECMK8zTiJNbI,12558
  sglang/srt/layers/get_selected_logprob.py,sha256=CpMXM9WXMSB-AVaxBB_aVl1Qx_ZtAFFnjDTm4CgNDpU,2199
  sglang/srt/layers/logits_processor.py,sha256=rwcXwdZ7-dW9zvJX3MF_EHSxMLbU7TIQ9xUIYRu-WAs,3013
  sglang/srt/layers/radix_attention.py,sha256=hmPNFg2TkN4EAVUj376N_89RRtUYRwFgUpjj5SydnRk,6170
@@ -40,18 +39,18 @@ sglang/srt/managers/io_struct.py,sha256=5jMWj6_U8yTQd5V3tpDtThnoFyF0A3ln-4Z5bSL3
  sglang/srt/managers/openai_protocol.py,sha256=Eid_734Wup4jsL1ZS2Op0vwRuzvNbF4mV2UcwFxqEvI,327
  sglang/srt/managers/tokenizer_manager.py,sha256=jVwr0lM18RFJLhDb5TWlUpQ4Q8tALT4L6GY0jmaZkLw,7861
  sglang/srt/managers/router/infer_batch.py,sha256=UfS1uVhGnM-62Xv1cfu_IoTeIUxkjkKc4W3trtGbadc,11541
- sglang/srt/managers/router/manager.py,sha256=H-T-LlnIssHw-FXMHbs3yDQewkTMBCqG6jTYjugopCA,2527
- sglang/srt/managers/router/model_rpc.py,sha256=ZLK5izxMGpfCs4uT7DJ8u-aww5UG_jwjr7eJdbWGZ3Y,19271
- sglang/srt/managers/router/model_runner.py,sha256=U-SBnEeLvwolLcaxyxrPgVG7PnR2rRvuXWV50t9y0Fo,16480
+ sglang/srt/managers/router/manager.py,sha256=AVCdYKKYcIQsIwpudkfFY4jh6M--ubLjXeYGzfi2ebw,2528
+ sglang/srt/managers/router/model_rpc.py,sha256=CR3qbHvShttlC19qAZ8B8nhT6UPobeu2Dy3Z0n6WdC8,19448
+ sglang/srt/managers/router/model_runner.py,sha256=IhSdpBcd54HN01HDi_PAkJztFxEGDcnktdoPZDWEx3s,16487
  sglang/srt/managers/router/radix_cache.py,sha256=ZQPm9HhQ7vD3Gl5nhuvw3ZW4ZRARcplqWed1GYUvHCg,6441
  sglang/srt/managers/router/scheduler.py,sha256=ejuIRwqqMZVXFKUionRJxy5AtNvK25YoGRO9rFY-rc8,2926
  sglang/srt/models/llama2.py,sha256=D3j-NtyM8PA74UhXM7wSPogI2HKX-JcQAWcOusrZZo0,11320
  sglang/srt/models/llava.py,sha256=COS0IC6Yo-QiwKe5emgCbtEe9HgaSu5tt6CQA7UtV38,8533
- sglang/srt/models/mixtral.py,sha256=j91xOt6NZ5tJiyTPqmUSzgJqFAw7vTDnfBtEs5x0jDM,13714
- sglang/test/test_programs.py,sha256=ua3wufnS3x6d_U3aboY4ivqoglrRPZj18j96vuiUtiE,11348
+ sglang/srt/models/mixtral.py,sha256=frd2XsNZwP0XsQtRiYhgy4PErLNLgtIsLakmNrOKBAU,13712
+ sglang/test/test_programs.py,sha256=EovA2xL7fODcTbFj2wAAmYKlg1mLZ1x1BRU6nrXFRdE,11416
  sglang/test/test_utils.py,sha256=Knxg3BTA6d_7XSlprbBCdvfDr2SN5x7LhkT-tZFk5EQ,4828
- sglang-0.1.3.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- sglang-0.1.3.dist-info/METADATA,sha256=SSRJ09MVErF7DrD5lJLm2oBDkk7sySET3AVaxJMciKs,21885
- sglang-0.1.3.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
- sglang-0.1.3.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
- sglang-0.1.3.dist-info/RECORD,,
+ sglang-0.1.5.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ sglang-0.1.5.dist-info/METADATA,sha256=aepmAL6VoXRcxZBIDKvxwikCYSbvWFm_JFGTxb3Mgfw,23345
+ sglang-0.1.5.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ sglang-0.1.5.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
+ sglang-0.1.5.dist-info/RECORD,,
sglang/backend/huggingface.py DELETED
@@ -1,349 +0,0 @@
- import functools
- from enum import Enum, auto
- from typing import Callable, List, Optional, Union
-
- import numpy as np
- import torch
- import transformers
- from sglang.backend.base_backend import BaseBackend
- from sglang.lang.chat_template import get_chat_template_by_model_path
- from sglang.lang.interpreter import ProgramState
- from sglang.utils import get_available_gpu_memory
- from transformers import (
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     StoppingCriteria,
-     StoppingCriteriaList,
- )
- from transformersgl.generation.logits_process import (
-     LogitsProcessorList,
-     RepetitionPenaltyLogitsProcessor,
-     TemperatureLogitsWarper,
-     TopKLogitsWarper,
-     TopPLogitsWarper,
- )
-
-
- class StopReason(Enum):
-     EOS_TOKEN = auto()
-     STOP_STR = auto()
-     LENGTH = auto()
-
-
- def load_model(
-     model_name: str,
-     device,
-     num_gpus,
-     max_gpu_memory,
-     model_kwargs=None,
-     tokenizer_kwargs=None,
- ):
-     model_kwargs = model_kwargs or {}
-     tokenizer_kwargs = tokenizer_kwargs or {}
-
-     if device == "cuda":
-         model_kwargs["torch_dtype"] = torch.float16
-         if num_gpus != 1:
-             model_kwargs["device_map"] = "auto"
-             if max_gpu_memory is None:
-                 model_kwargs[
-                     "device_map"
-                 ] = "sequential"  # This is important for not the same VRAM sizes
-                 available_gpu_memory = [
-                     get_available_gpu_memory(i, False) for i in range(num_gpus)
-                 ]
-                 model_kwargs["max_memory"] = {
-                     i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
-                     for i in range(num_gpus)
-                 }
-             else:
-                 model_kwargs["max_memory"] = {
-                     i: max_gpu_memory for i in range(num_gpus)
-                 }
-     elif device == "cpu":
-         model_kwargs["torch_dtype"] = torch.float32
-     else:
-         raise ValueError(f"Invalid device: {device}")
-
-     model = AutoModelForCausalLM.from_pretrained(
-         model_name, low_cpu_mem_usage=True, **model_kwargs
-     )
-     tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
-
-     if num_gpus == 1:
-         model.to(device).eval()
-
-     return model, tokenizer
-
-
- def prepare_logits_processor(
-     temperature: float, repetition_penalty: float, top_p: float, top_k: int
- ) -> LogitsProcessorList:
-     processor_list = LogitsProcessorList()
-     # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
-     if temperature >= 1e-5 and temperature != 1.0:
-         processor_list.append(TemperatureLogitsWarper(temperature))
-     if repetition_penalty > 1.0:
-         processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
-     if 1e-8 <= top_p < 1.0:
-         processor_list.append(TopPLogitsWarper(top_p))
-     if top_k > 0:
-         processor_list.append(TopKLogitsWarper(top_k))
-     return processor_list
-
-
- @functools.lru_cache
- def get_token_healing_mask(tokenizer, prompt_last_token):
-     last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
-     disallowed = torch.zeros(len(tokenizer), dtype=bool)
-     for s, t_id in tokenizer.get_vocab().items():
-         if not s.startswith(last_str):
-             disallowed[t_id] = 1
-     return disallowed
-
-
- @functools.lru_cache
- def get_int_token_mask(tokenizer):
-     disallowed = torch.zeros(len(tokenizer), dtype=bool)
-     for s, t_id in tokenizer.get_vocab().items():
-         s = s.replace("▁", "").strip()
-         if not (s.isdigit() or len(s) == 0 or s == ","):
-             disallowed[t_id] = 1
-     disallowed[tokenizer.eos_token_id] = 0
-     return disallowed
-
-
- @torch.inference_mode()
- def generate_stream(
-     model,
-     tokenizer,
-     prompt,
-     max_new_tokens,
-     stop: List[str],
-     temperature,
-     top_p,
-     token_healing,
-     logit_mask=None,
- ):
-     logits_processor = prepare_logits_processor(
-         temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
-     )
-     device = model.device
-     input_ids = tokenizer.encode(prompt)
-     output_ids = list(input_ids)
-     prompt_len = len(prompt)
-
-     # Resolve stop
-     stop_token_ids = [tokenizer.eos_token_id]
-
-     # Token healing
-     token_healing = token_healing and len(input_ids) > 0
-     if token_healing:
-         token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
-         del output_ids[-1]
-
-     # Generate
-     past_key_values = None
-     stop_reason = None
-     for i in range(max_new_tokens):
-         # Forward
-         if i == 0:  # prefill
-             out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
-         else:  # decoding
-             out = model(
-                 input_ids=torch.as_tensor([[token]], device=device),
-                 use_cache=True,
-                 past_key_values=past_key_values,
-             )
-         logits = out.logits
-         past_key_values = out.past_key_values
-
-         # Logit mask
-         if token_healing and i == 0:
-             logits[0, -1, token_healing_mask] = -1e4
-         if logit_mask is not None:
-             logits[0, -1, logit_mask] = -1e4
-
-         # Sample next token
-         last_token_logits = logits_processor(None, logits[:, -1, :])[0]
-         if temperature < 1e-5 or top_p < 1e-8:  # greedy
-             token = int(torch.argmax(last_token_logits))
-         else:
-             probs = torch.softmax(last_token_logits, dim=-1)
-             token = int(torch.multinomial(probs, num_samples=1))
-         output_ids.append(token)
-
-         # Stop condition
-         if token in stop_token_ids:
-             stop_reason = StopReason.EOS_TOKEN
-             break
-
-         output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
-         for stop_str in stop:
-             pos = output_str[prompt_len:].find(stop_str)
-             if pos != -1:
-                 stop_reason = StopReason.STOP_STR
-                 output_str = output_str[: prompt_len + pos]
-                 break
-
-         if stop_reason:
-             break
-
-     return output_str[prompt_len:]
-
-
- class HuggingFaceTransformers(BaseBackend):
-     def __init__(
-         self,
-         model_name,
-         device="cuda",
-         num_gpus=1,
-         max_gpu_memory=None,
-         model_kwargs=None,
-         tokenizer_kwargs=None,
-     ):
-         self.model_name = model_name
-         self.device = device
-
-         self.model, self.tokenizer = load_model(
-             model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
-         )
-
-         self.chat_template = get_chat_template_by_model_path(model_name)
-
-     def get_chat_template(self):
-         return self.chat_template
-
-     def cache_prefix(self, prefix_str: str):
-         pass
-
-     def uncache_prefix(self, rid: str):
-         pass
-
-     def end_request(self, rid: str):
-         pass
-
-     def begin_program(self, s: ProgramState):
-         pass
-
-     def end_program(self, s: ProgramState):
-         pass
-
-     def fill(self, s: ProgramState, text: str):
-         return False
-
-     def generate_internal(
-         self,
-         prompt: str,
-         max_tokens: int,
-         stop: Union[str, List[str]],
-         temperature: float,
-         top_p: float,
-         dtype: Optional[str] = None,
-     ):
-         if dtype is None:
-             comp = generate_stream(
-                 self.model,
-                 self.tokenizer,
-                 prompt,
-                 max_new_tokens=max_tokens,
-                 stop=stop,
-                 temperature=temperature,
-                 top_p=top_p,
-                 token_healing=True,
-             )
-         elif dtype in [str, "str", "string"]:
-             comp = generate_stream(
-                 self.model,
-                 self.tokenizer,
-                 prompt + '"',
-                 max_new_tokens=max_tokens,
-                 stop=['"'],
-                 temperature=temperature,
-                 top_p=top_p,
-                 token_healing=False,
-             )
-             comp = '"' + comp + '"'
-         elif dtype in [int, "int"]:
-             logit_mask = get_int_token_mask(self.tokenizer)
-             comp = generate_stream(
-                 self.model,
-                 self.tokenizer,
-                 prompt,
-                 max_new_tokens=max_tokens,
-                 stop=stop + [" ", ","],
-                 temperature=temperature,
-                 top_p=top_p,
-                 token_healing=False,
-                 logit_mask=logit_mask,
-             )
-         return comp
-
-     def generate(
-         self,
-         s: ProgramState,
-         max_tokens: int,
-         stop: Union[str, List[str]],
-         temperature: float,
-         top_p: float,
-         dtype: Optional[str] = None,
-     ):
-         prompt = s.text
-         comp = self.generate_internal(
-             prompt, max_tokens, stop, temperature, top_p, dtype
-         )
-         return comp
-
-     def parallel_generate(
-         self,
-         s: ProgramState,
-         prefixes: List[str],
-         join_func: Callable,
-         max_tokens: int,
-         stop: Union[str, List[str]],
-         temperature: float,
-         top_p: float,
-         dtype: Optional[str] = None,
-     ):
-         prompt = s.text
-         parallel_prompts = [prompt + prefix for prefix in prefixes]
-
-         comps = []
-         for i in range(len(parallel_prompts)):
-             comps.append(
-                 self.generate_internal(
-                     parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
-                 )
-             )
-
-         joined = join_func([p + c for p, c in zip(prefixes, comps)])
-         return joined, comps
-
-     @torch.inference_mode()
-     def select(
-         self, s: ProgramState, choices: List[str], temperature: float, top_p: float
-     ):
-         loss_fct = torch.nn.CrossEntropyLoss()
-         prompt = s.text
-
-         prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
-         prompt_choices = [prompt + choice for choice in choices]
-
-         scores = []
-         for i in range(len(choices)):
-             choice_ids = self.tokenizer.encode(
-                 prompt_choices[i], return_tensors="pt"
-             ).to(self.model.device)
-             logits = self.model(choice_ids).logits
-
-             # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()
-
-             logprobs = torch.log(torch.softmax(logits, dim=-1))
-             idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
-             idx2 = choice_ids[0, 1:]
-             selected_logprobs = logprobs[0, idx1, idx2]
-             score = selected_logprobs.mean().item()
-             scores.append(score)
-
-         decision = choices[np.argmax(scores)]
-         return decision, scores
sglang/backend/tgi.py DELETED
@@ -1,190 +0,0 @@
- import re
- from concurrent.futures import ThreadPoolExecutor
- from functools import partial
- from itertools import repeat
- from typing import List, Optional, Union
-
- from sglang.backend.base_backend import BaseBackend
- from sglang.lang.chat_template import get_chat_template_by_model_path
- from sglang.lang.interpreter import StreamExecutor
- from sglang.lang.ir import SglSamplingParams
- from sglang.utils import http_request
-
-
- class TGI(BaseBackend):
-     def __init__(self, base_url):
-         super().__init__()
-
-         self.base_url = base_url
-
-         res = http_request(self.base_url + "/info")
-         assert res.status_code == 200
-         self.model_info = res.json()
-         self.chat_template = get_chat_template_by_model_path(
-             self.model_info["model_id"]
-         )
-
-     def get_model_name(self):
-         return self.model_info["model_id"]
-
-     def get_chat_template(self):
-         return self.chat_template
-
-     @staticmethod
-     def adapt_params(max_tokens, stop, sampling_params, **override_params):
-         temperature = sampling_params.temperature
-         do_sample = True
-         if temperature == 0:
-             do_sample = False
-             temperature = None
-
-         if stop is None:
-             stop = []
-         elif isinstance(stop, str):
-             stop = [stop]
-
-         top_p = sampling_params.top_p
-         if top_p == 0:
-             top_p = 0.001
-         if top_p == 1:
-             top_p = 0.999
-
-         top_k = sampling_params.top_k
-         if top_k == -1:
-             top_k = None
-
-         params = {
-             "decoder_input_details": False,
-             "details": False,
-             "do_sample": do_sample,
-             "max_new_tokens": max_tokens,
-             "stop": stop,
-             "temperature": temperature,
-             "top_p": top_p,
-             "top_k": top_k,
-             "return_full_text": False,
-         }
-         params.update(override_params)
-         return params
-
-     @staticmethod
-     def _extract_int(text):
-         words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
-         for word in words:
-             try:
-                 int(word)
-                 return word
-             except ValueError:
-                 continue
-         raise ValueError
-
-     @staticmethod
-     def _extract_choice(choices, text):
-         # FIXME: Current only support the case where the choices are single words.
-         words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
-         for word in words:
-             if word in choices:
-                 return word
-         raise ValueError
-
-     @staticmethod
-     def _truncate_to_stop(text, stop):
-         # The stop sequence may not be a single token. In this case TGI will generate
-         # too many tokens so we need to truncate the output.
-         if stop:
-             stop = [stop] if isinstance(stop, str) else stop
-             for stop_seq in stop:
-                 pos = text.find(stop_seq)
-                 if pos != -1:
-                     return text[:pos]
-         return text
-
-     def _make_request(self, params):
-         res = http_request(self.base_url + "/generate", json=params)
-         if res.status_code != 200:
-             raise ValueError(f"Error from TGI backend: {res.text}")
-         return res.json()
-
-     def retry_for_expected(self, prompt, params, extract_fn, retry=5):
-         # TGI does not support logis_bias (yet), so we have to use an inefficient hack.
-         failed = []
-         while retry > 0:
-             res_json = self._make_request(
-                 {
-                     "inputs": prompt,
-                     "parameters": params,
-                 }
-             )
-             text = res_json["generated_text"]
-             try:
-                 return extract_fn(text)
-             except ValueError:
-                 retry -= 1
-                 failed.append(text)
-
-         msg = "=" * 20 + "\n"
-         msg += f"Prompt:\n{prompt}\n"
-         msg += "=" * 20 + "\n"
-         for i, text in enumerate(failed):
-             msg += f"====== Try {i+1}:\n{text}\n"
-
-         raise ValueError(
-             f"Model {self.model_info['model_id']} served by TGI backend does not generate"
-             "expected output. Please improve the prompt, increase the temperature, or "
-             f"use different models.\n{msg}"
-         )
-
-     def select(
-         self,
-         s: StreamExecutor,
-         choices: List[str],
-         sampling_params: SglSamplingParams,
-     ):
-         decision = self.retry_for_expected(
-             s.text_,
-             self.adapt_params(16, [], sampling_params),
-             partial(self._extract_choice, choices),
-         )
-         return decision, [1 if choice == decision else 0 for choice in choices]
-
-     def generate(
-         self,
-         s: StreamExecutor,
-         max_tokens: int,
-         stop: Union[str, List[str]],
-         sampling_params: SglSamplingParams,
-         dtype: Optional[str] = None,
-     ):
-         if dtype is None:
-             res_json = self._make_request(
-                 {
-                     "inputs": s.text_,
-                     "parameters": self.adapt_params(max_tokens, stop, sampling_params),
-                 }
-             )
-             return self._truncate_to_stop(res_json["generated_text"], stop), {}
-
-         if dtype in [str, "str", "string"]:
-             stop = ['"']
-             res_json = self._make_request(
-                 {
-                     "inputs": f'{s.text_}"',
-                     "parameters": self.adapt_params(max_tokens, stop, sampling_params),
-                 }
-             )
-             return (
-                 '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
-                 {},
-             )
-
-         if dtype in [int, "int"]:
-             return (
-                 self.retry_for_expected(
-                     s.text_,
-                     self.adapt_params(max_tokens, stop, sampling_params),
-                     self._extract_int,
-                 ),
-                 {},
-             )
-
-         raise ValueError(f"Unknown dtype: {dtype}")
File without changes