sglang 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.4"
+ __version__ = "0.1.5"
 
  from sglang.api import *
  from sglang.global_config import global_config
sglang/api.py CHANGED
@@ -6,6 +6,7 @@ from sglang.backend.anthropic import Anthropic
  from sglang.backend.base_backend import BaseBackend
  from sglang.backend.openai import OpenAI
  from sglang.backend.runtime_endpoint import RuntimeEndpoint
+ from sglang.backend.vertexai import VertexAI
  from sglang.global_config import global_config
  from sglang.lang.ir import (
      SglExpr,
sglang/backend/vertexai.py ADDED
@@ -0,0 +1,147 @@
+ import os
+ import warnings
+ from typing import List, Optional, Union
+
+ import numpy as np
+ from sglang.backend.base_backend import BaseBackend
+ from sglang.lang.chat_template import get_chat_template
+ from sglang.lang.interpreter import StreamExecutor
+ from sglang.lang.ir import SglSamplingParams
+
+ try:
+     import vertexai
+     from vertexai.preview.generative_models import (
+         GenerationConfig,
+         GenerativeModel,
+         Image,
+     )
+ except ImportError as e:
+     GenerativeModel = e
+
+
+ class VertexAI(BaseBackend):
+     def __init__(self, model_name):
+         super().__init__()
+
+         if isinstance(GenerativeModel, Exception):
+             raise GenerativeModel
+
+         project_id = os.environ["GCP_PROJECT_ID"]
+         location = os.environ.get("GCP_LOCATION")
+         vertexai.init(project=project_id, location=location)
+
+         self.model_name = model_name
+         self.chat_template = get_chat_template("default")
+
+     def get_chat_template(self):
+         return self.chat_template
+
+     def generate(
+         self,
+         s: StreamExecutor,
+         sampling_params: SglSamplingParams,
+     ):
+         if s.messages_:
+             prompt = self.messages_to_vertexai_input(s.messages_)
+         else:
+             # single-turn
+             prompt = (
+                 self.text_to_vertexai_input(s.text_, s.cur_images)
+                 if s.cur_images
+                 else s.text_
+             )
+         ret = GenerativeModel(self.model_name).generate_content(
+             prompt,
+             generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+         )
+
+         comp = ret.text
+
+         return comp, {}
+
+     def generate_stream(
+         self,
+         s: StreamExecutor,
+         sampling_params: SglSamplingParams,
+     ):
+         if s.messages_:
+             prompt = self.messages_to_vertexai_input(s.messages_)
+         else:
+             # single-turn
+             prompt = (
+                 self.text_to_vertexai_input(s.text_, s.cur_images)
+                 if s.cur_images
+                 else s.text_
+             )
+         generator = GenerativeModel(self.model_name).generate_content(
+             prompt,
+             stream=True,
+             generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
+         )
+         for ret in generator:
+             yield ret.text, {}
+
+     def text_to_vertexai_input(self, text, images):
+         input = []
+         # split with image token
+         text_segs = text.split(self.chat_template.image_token)
+         for image_path, image_base64_data in images:
+             text_seg = text_segs.pop(0)
+             if text_seg != "":
+                 input.append(text_seg)
+             input.append(Image.from_bytes(image_base64_data))
+         text_seg = text_segs.pop(0)
+         if text_seg != "":
+             input.append(text_seg)
+         return input
+
+     def messages_to_vertexai_input(self, messages):
+         vertexai_message = []
+         # from openai message format to vertexai message format
+         for msg in messages:
+             if isinstance(msg["content"], str):
+                 text = msg["content"]
+             else:
+                 text = msg["content"][0]["text"]
+
+             if msg["role"] == "system":
+                 warnings.warn("Warning: system prompt is not supported in VertexAI.")
+                 vertexai_message.append(
+                     {
+                         "role": "user",
+                         "parts": [{"text": "System prompt: " + text}],
+                     }
+                 )
+                 vertexai_message.append(
+                     {
+                         "role": "model",
+                         "parts": [{"text": "Understood."}],
+                     }
+                 )
+                 continue
+             if msg["role"] == "user":
+                 vertexai_msg = {
+                     "role": "user",
+                     "parts": [{"text": text}],
+                 }
+             elif msg["role"] == "assistant":
+                 vertexai_msg = {
+                     "role": "model",
+                     "parts": [{"text": text}],
+                 }
+
+             # images
+             if isinstance(msg["content"], list) and len(msg["content"]) > 1:
+                 for image in msg["content"][1:]:
+                     assert image["type"] == "image_url"
+                     vertexai_msg["parts"].append(
+                         {
+                             "inline_data": {
+                                 "data": image["image_url"]["url"].split(",")[1],
+                                 "mime_type": "image/jpeg",
+                             }
+                         }
+                     )
+
+             vertexai_message.append(vertexai_msg)
+         return vertexai_message
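For orientation, here is a minimal sketch of how the new backend could be driven from user code. It assumes the `@sgl.function` / `sgl.set_default_backend` API shown in the README below, that `VertexAI` is re-exported from the top-level package (see the `sglang/api.py` change above), and that the model name and GCP values are placeholders rather than anything prescribed by this release.

```python
import os

import sglang as sgl

# The backend reads these at construction time (see __init__ above).
# Both values are placeholders for your own project.
os.environ["GCP_PROJECT_ID"] = "my-gcp-project"
os.environ["GCP_LOCATION"] = "us-central1"


@sgl.function
def multi_turn_question(s, question_1, question_2):
    s += sgl.user(question_1)
    s += sgl.assistant(sgl.gen("answer_1", max_tokens=128))
    s += sgl.user(question_2)
    s += sgl.assistant(sgl.gen("answer_2", max_tokens=128))


# "gemini-pro" is an illustrative model name.
sgl.set_default_backend(sgl.VertexAI("gemini-pro"))

state = multi_turn_question.run(
    question_1="What is the capital of France?",
    question_2="List two landmarks there.",
)
print(state["answer_1"])
print(state["answer_2"])
```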
sglang/lang/interpreter.py CHANGED
@@ -365,11 +365,10 @@ class StreamExecutor:
          for comp, meta_info in generator:
              self.text_ += comp
              self.variables[name] += comp
+             self.meta_info[name] = meta_info
              self.stream_var_event[name].set()
              self.stream_text_event.set()
 
-         self.meta_info[name] = meta_info
-
          self.variable_event[name].set()
          self.stream_var_event[name].set()
 
@@ -428,6 +427,7 @@ class StreamExecutor:
              self.messages_.append(last_msg)
              self.cur_images = []
          else:
+             # OpenAI chat API format
              self.messages_.append({"role": expr.role, "content": new_text})
 
          self.cur_role = None
@@ -582,7 +582,7 @@ class ProgramState:
          else:
              yield self.get_var(name)
 
-     async def text_async_iter(self, var_name=None):
+     async def text_async_iter(self, var_name=None, return_meta_data=False):
          loop = asyncio.get_running_loop()
 
          if self.stream_executor.stream:
@@ -606,7 +606,10 @@ class ProgramState:
                  out = str(self.stream_executor.variables[var_name][prev:])
                  prev += len(out)
                  if out:
-                     yield out
+                     if return_meta_data:
+                         yield out, self.stream_executor.meta_info[var_name]
+                     else:
+                         yield out
                  if self.stream_executor.variable_event[var_name].is_set():
                      break
          else:
@@ -632,11 +635,7 @@ class ProgramState:
          self.stream_executor.end()
 
      def __repr__(self) -> str:
-         msgs = self.messages()
-         ret = ""
-         for msg in msgs:
-             ret += msg["role"] + ":\n" + msg["content"] + "\n"
-         return ret
+         return f"ProgramState({self.text()})"
 
 
  class ProgramStateGroup:
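The `return_meta_data` flag above lets a streaming consumer read the per-variable `meta_info` recorded by the executor alongside the text chunks. A rough usage sketch, assuming a default backend has already been set and that `run(..., stream=True)` returns immediately, as in the README's streaming example:

```python
import asyncio

import sglang as sgl

# Assumes something like sgl.set_default_backend(...) has been called already.


@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", max_tokens=64, stop="\n")


async def main():
    # stream=True lets us consume the output while generation is still running.
    state = text_qa.run(question="What is the capital of France?", stream=True)

    meta = None
    # With return_meta_data=True, each chunk is paired with the meta info
    # stored for the "answer" variable by the executor above.
    async for chunk, meta in state.text_async_iter("answer", return_meta_data=True):
        print(chunk, end="", flush=True)
    print("\nmeta_info:", meta)


asyncio.run(main())
```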
sglang/lang/ir.py CHANGED
@@ -2,6 +2,7 @@
 
  import dataclasses
  import inspect
+ import warnings
  from typing import List, Optional, Union
 
  from sglang.global_config import global_config
@@ -40,6 +41,8 @@ class SglSamplingParams:
 
      def to_openai_kwargs(self):
          # OpenAI does not support top_k, so we drop it here
+         if self.regex is not None:
+             warnings.warn("Regular expression is not supported in the OpenAI backend.")
          return {
              "max_tokens": self.max_new_tokens,
              "stop": self.stop or None,
@@ -49,8 +52,26 @@ class SglSamplingParams:
              "presence_penalty": self.presence_penalty,
          }
 
+     def to_vertexai_kwargs(self):
+         if self.regex is not None:
+             warnings.warn(
+                 "Regular expression is not supported in the VertexAI backend."
+             )
+         return {
+             "candidate_count": 1,
+             "max_output_tokens": self.max_new_tokens,
+             "stop_sequences": self.stop,
+             "temperature": self.temperature,
+             "top_p": self.top_p,
+             "top_k": self.top_k if self.top_k > 0 else None,
+         }
+
      def to_anthropic_kwargs(self):
          # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
+         if self.regex is not None:
+             warnings.warn(
+                 "Regular expression is not supported in the Anthropic backend."
+             )
          return {
              "max_tokens_to_sample": self.max_new_tokens,
              "stop_sequences": self.stop,
sglang/srt/layers/context_flashattention_nopad.py CHANGED
@@ -5,7 +5,6 @@ import triton
  import triton.language as tl
  from sglang.srt.utils import wrap_kernel_launcher
 
-
  CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
sglang/srt/layers/extend_attention.py CHANGED
@@ -4,7 +4,6 @@ import triton.language as tl
  from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
  from sglang.srt.utils import wrap_kernel_launcher
 
-
  CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
sglang/srt/managers/router/manager.py CHANGED
@@ -28,7 +28,7 @@ class RouterManager:
          self.model_client = model_client
          self.recv_reqs = []
 
-         # Init Some Configs
+         # Init some configs
          self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
 
      async def loop_for_forward(self):
@@ -46,7 +46,7 @@ class RouterManager:
              if has_finished:
                  await asyncio.sleep(self.extend_dependency_time)
 
-             await asyncio.sleep(0.001)
+             await asyncio.sleep(0.0006)
 
      async def loop_for_recv_requests(self):
          while True:
sglang/srt/managers/router/model_rpc.py CHANGED
@@ -2,10 +2,10 @@ import asyncio
  import logging
  import multiprocessing
  import time
+ import warnings
  from concurrent.futures import ThreadPoolExecutor
  from enum import Enum, auto
  from typing import Dict, List, Optional, Tuple, Union
- import warnings
 
  import numpy as np
  import rpyc
@@ -45,6 +45,7 @@ class ModelRpcServer(rpyc.Service):
          self.tp_rank = tp_rank
          self.tp_size = server_args.tp_size
          self.schedule_heuristic = server_args.schedule_heuristic
+         self.schedule_conservativeness = server_args.schedule_conservativeness
 
          # Init model and tokenizer
          self.model_config = ModelConfig(
@@ -108,7 +109,7 @@ class ModelRpcServer(rpyc.Service):
          self.running_batch: Batch = None
          self.out_pyobjs = []
          self.decode_forward_ct = 0
-         self.stream_interval = 2
+         self.stream_interval = server_args.stream_interval
 
          # Init the FSM cache for constrained generation
          self.regex_fsm_cache = FSMCache(self.tokenizer)
@@ -248,7 +249,9 @@ class ModelRpcServer(rpyc.Service):
          available_size = (
              self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
          )
-         new_ratio = self.scheduler.new_token_estimation_ratio()
+         new_ratio = (
+             self.scheduler.new_token_estimation_ratio() * self.schedule_conservativeness
+         )
          if self.running_batch:
              available_size -= sum(
                  [
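The intent of `schedule_conservativeness` in the hunk above: the scheduler's estimate of tokens still needed by running requests is multiplied by this factor before deciding how many new requests fit in the KV cache. A schematic toy illustration, not the server's actual admission code; all names and numbers below are made up:

```python
def can_admit_new_request(
    new_request_tokens: int,
    available_kv_slots: int,
    est_tokens_per_running_req: float,
    num_running: int,
    schedule_conservativeness: float = 1.0,
) -> bool:
    # A larger factor inflates the reservation for the running batch, so fewer
    # new requests are admitted; this trades throughput for out-of-memory safety.
    reserved = est_tokens_per_running_req * schedule_conservativeness * num_running
    return new_request_tokens <= available_kv_slots - reserved


print(can_admit_new_request(512, 4096, 60.0, 50, schedule_conservativeness=1.0))  # True
print(can_admit_new_request(512, 4096, 60.0, 50, schedule_conservativeness=1.3))  # False
```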
sglang/srt/managers/router/model_runner.py CHANGED
@@ -278,7 +278,7 @@ class ModelRunner:
              load_format=self.load_format,
              revision=None,
          )
-         self.model = model
+         self.model = model.eval()
 
      def profile_max_num_token(self, total_gpu_memory):
          available_gpu_memory = get_available_gpu_memory(
sglang/srt/models/mixtral.py CHANGED
@@ -355,7 +355,7 @@ class MixtralForCausalLM(nn.Module):
      ):
          if "rotary_emb.inv_freq" in name:
              continue
-         for (param_name, weight_name, shard_id) in stacked_params_mapping:
+         for param_name, weight_name, shard_id in stacked_params_mapping:
              if weight_name not in name:
                  continue
              name = name.replace(weight_name, param_name)
sglang/srt/server_args.py CHANGED
@@ -16,7 +16,9 @@ class ServerArgs:
      tp_size: int = 1
      model_mode: List[str] = ()
      schedule_heuristic: str = "lpm"
+     schedule_conservativeness: float = 1.0
      random_seed: int = 42
+     stream_interval: int = 2
      disable_log_stats: bool = False
      log_stats_interval: int = 10
      log_level: str = "info"
@@ -25,10 +27,14 @@ class ServerArgs:
          if self.tokenizer_path is None:
              self.tokenizer_path = self.model_path
          if self.mem_fraction_static is None:
-             if self.tp_size > 1:
-                 self.mem_fraction_static = 0.8
+             if self.tp_size >= 8:
+                 self.mem_fraction_static = 0.80
+             elif self.tp_size >= 4:
+                 self.mem_fraction_static = 0.82
+             elif self.tp_size >= 2:
+                 self.mem_fraction_static = 0.85
              else:
-                 self.mem_fraction_static = 0.9
+                 self.mem_fraction_static = 0.90
 
      @staticmethod
      def add_cli_args(parser: argparse.ArgumentParser):
@@ -80,7 +86,7 @@ class ServerArgs:
              "--mem-fraction-static",
              type=float,
              default=ServerArgs.mem_fraction_static,
-             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool)",
+             help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
          )
          parser.add_argument(
              "--tp-size",
@@ -102,12 +108,24 @@ class ServerArgs:
              default=ServerArgs.schedule_heuristic,
              help="Schudule mode: [lpm, weight, random, fcfs]",
          )
+         parser.add_argument(
+             "--schedule-conservativeness",
+             type=float,
+             default=ServerArgs.schedule_conservativeness,
+             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see out-of-memory errors.",
+         )
          parser.add_argument(
              "--random-seed",
              type=int,
              default=ServerArgs.random_seed,
              help="Random seed.",
          )
+         parser.add_argument(
+             "--stream-interval",
+             type=int,
+             default=ServerArgs.stream_interval,
+             help="The interval in terms of token length for streaming",
+         )
          parser.add_argument(
              "--log-level",
              type=str,
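A quick sketch of the new defaults, assuming `ServerArgs` is a dataclass whose `__post_init__` runs the memory-fraction logic above and that `model_path` is its required field; the model path shown is just the one used elsewhere in this README.

```python
from sglang.srt.server_args import ServerArgs

args = ServerArgs(model_path="meta-llama/Llama-2-7b-chat-hf", tp_size=4)

# New knobs introduced in 0.1.5 and their defaults.
print(args.schedule_conservativeness)  # 1.0
print(args.stream_interval)            # 2

# mem_fraction_static is now staged by tensor-parallel size when left unset:
# tp_size >= 8 -> 0.80, >= 4 -> 0.82, >= 2 -> 0.85, otherwise 0.90.
print(args.mem_fraction_static)        # 0.82
```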
sglang/test/test_programs.py CHANGED
@@ -304,7 +304,10 @@ def test_image_qa():
          temperature=0,
          max_new_tokens=64,
      )
-     assert "taxi" in state.messages()[-1]["content"]
+     assert (
+         "taxi" in state.messages()[-1]["content"]
+         or "car" in state.messages()[-1]["content"]
+     )
 
 
  def test_stream():
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sglang
- Version: 0.1.4
+ Version: 0.1.5
  Summary: A structured generation langauge for LLMs.
  License: Apache License
  Version 2.0, January 2004
@@ -234,6 +234,7 @@ Requires-Dist: lark ; extra == 'srt'
  Requires-Dist: numba ; extra == 'srt'
 
  # SGLang
+ | [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |
 
  SGLang is a structured generation language designed for large language models (LLMs).
  It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.
@@ -277,7 +278,7 @@ The example below shows how to use sglang to answer a mulit-turn question.
  ### Using OpenAI Models
  Set the OpenAI API Key
  ```
- export OPENAI_API_KEY=sk-xxxxxx
+ export OPENAI_API_KEY=sk-******
  ```
 
  Then, answer a multi-turn question.
@@ -335,6 +336,7 @@ for m in state.messages():
 
  ### More Examples
 
+ Anthropic and VertexAI (Gemini) models are also supported.
  You can find more examples at [examples/quick_start](examples/quick_start).
 
  ## Frontend: Structured Generation Langauge (SGLang)
@@ -350,13 +352,14 @@ You can then invoke the function with `run` or `run_batch`.
  The system will manage the state, chat template, and parallelism for you.
 
  ### Control Flow
+ You can use any Python code within the function body, including control flow, nested function calls, and external libraries.
+
  ```python
  @sgl.function
  def control_flow(s, question):
      s += "To answer this question: " + question + ", "
      s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "
 
-     # You can use if or nested function calls
      if s["tool"] == "calculator":
          s += "The math expression is" + sgl.gen("expression")
      elif s["tool"] == "web browser":
@@ -364,6 +367,9 @@ def control_flow(s, question):
  ```
 
  ### Parallelism
+ Use `fork` to launch parallel prompts.
+ Because `sgl.gen` is non-blocking, the for loop below issues two generation calls in parallel.
+
  ```python
  @sgl.function
  def tip_suggestion(s):
@@ -372,7 +378,7 @@ def tip_suggestion(s):
          "1. Balanced Diet. 2. Regular Exercise.\n\n"
      )
 
-     forks = s.fork(2)  # Launch parallel prompts
+     forks = s.fork(2)
      for i, f in enumerate(forks):
          f += f"Now, expand tip {i+1} into a paragraph:\n"
          f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n")
@@ -383,6 +389,8 @@ def tip_suggestion(s):
  ```
 
  ### Multi Modality
+ Use `sgl.image` to pass an image as input.
+
  ```python
  @sgl.function
  def image_qa(s, image_file, question):
@@ -391,6 +399,8 @@ def image_qa(s, image_file, question):
  ```
 
  ### Constrained Decoding
+ Use `regex=` to specify a regular expression as a decoding constraint.
+
  ```python
  @sgl.function
  def regular_expression_gen(s):
@@ -403,6 +413,8 @@ def regular_expression_gen(s):
  ```
 
  ### Batching
+ Use `run_batch` to run a batch of requests with continuous batching.
+
  ```python
  @sgl.function
  def text_qa(s, question):
@@ -415,10 +427,13 @@ states = text_qa.run_batch(
      {"question": "What is the capital of France?"},
      {"question": "What is the capital of Japan?"},
      ],
+     progress_bar=True
  )
  ```
 
  ### Streaming
+ Add `stream=True` to enable streaming.
+
  ```python
  @sgl.function
  def text_qa(s, question):
@@ -427,7 +442,9 @@ def text_qa(s, question):
 
  states = text_qa.run(
      question="What is the capital of France?",
-     temperature=0.1)
+     temperature=0.1,
+     stream=True
+ )
 
  for out in state.text_iter():
      print(out, end="", flush=True)
@@ -471,6 +488,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
  - Mixtral
  - LLaVA
  - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
+ - AWQ quantization
 
  ## Benchmark And Performance
 
@@ -483,10 +501,10 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
  Learn more [here](docs/benchmark_results.md).
 
  ## Roadmap
- - [ ] Function call
- - [ ] Quantization
+ - [ ] Function call APIs
  - [ ] S-LoRA
- - [ ] More models
+ - [ ] Support more models
+ - [ ] Support more hardware backends
 
  ## Citation And Acknowledgment
  ```
@@ -1,5 +1,5 @@
- sglang/__init__.py,sha256=lfYPLrb_Fy-J-l7NGMhnRDk0hlvAkCIzJEjzN6AsV0g,95
- sglang/api.py,sha256=tJuEyB28BUQfl0-dQr4vi6UMHBhUbmyu9Z3iAE5xFcU,3883
+ sglang/__init__.py,sha256=G73L_PWJ_6mF3NIE4ZAOWcb1CUbETSeRWr3wDTePrZ4,95
+ sglang/api.py,sha256=SxmPP_PMYi4DfUcwz_V9UvYOwGmQdHPgpMV6jDDJq68,3928
  sglang/flush_cache.py,sha256=cCD_MTlQ5qEv__w0nOthDnVitdAfyscYjksBljwC5Mw,1835
  sglang/global_config.py,sha256=PAX7TWeFcq0HBzNUWyCONAOjqIokWqw8vT7I6sBSKTc,797
  sglang/launch_server.py,sha256=jKPZRDN5bUe8Wgz5eoDkqeePhmKa8DLD4DpXQLT5auo,294
@@ -7,15 +7,14 @@ sglang/utils.py,sha256=tvJs95QGZ_PcnTjvm-CDGQ8dJe84qUUOfG7BeF79nsA,5670
  sglang/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sglang/backend/anthropic.py,sha256=y5TN9EDrJtOH4JEUxpXu-endloeYBy7xMUr3r7Ah3MA,1462
  sglang/backend/base_backend.py,sha256=pPalZfoezxnUBs752j7lm0uMwa8tZuCWd-ijSdStMO8,1745
- sglang/backend/huggingface.py,sha256=roQlt8y41PQbmnAY47CXiR0KJaxhtljH6j8RhbsR4f0,10533
  sglang/backend/openai.py,sha256=umTWzC2p4PypDaXHe6Kc8By5IM_Doi0Ob97vK_fFWDc,7367
  sglang/backend/runtime_endpoint.py,sha256=rIhwtKJaLLCJAc6q6kqxEVC8xO_NNjmJs7BnxlOydLM,5860
- sglang/backend/tgi.py,sha256=2wlfparGJNLN806bvPi_8jsk6ezJG1QviSZu2IBf1No,5935
+ sglang/backend/vertexai.py,sha256=BLfWf_tEgoHY9srCufJM5PLe3tql2j0G6ia7cPykxCM,4713
  sglang/lang/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  sglang/lang/chat_template.py,sha256=1x4724K2oxu7VID40-5Megk7SbZb97PQCbRjLpoescU,5599
  sglang/lang/compiler.py,sha256=wNn_UqV6Sxl22mv-PpzFUtRgiFFV-Y4OYpO4LshEoRM,7527
- sglang/lang/interpreter.py,sha256=YqCqsVonZt_xwL1ZMBNXHRSyxnGUVQr736ESn1Q7NWE,22339
- sglang/lang/ir.py,sha256=9vUL68VkT3gmcDaLjTiwerM21UwlQg-95FRaIt32jSM,12380
+ sglang/lang/interpreter.py,sha256=0WTJxCB57WDBr_E6kW39wByhcG2nRFjEMTzOjAaNhrY,22453
+ sglang/lang/ir.py,sha256=uUnBRyaM-8suVOEb2qf4EAt_VN2pWbXV6V88jLk6wsI,13160
  sglang/lang/tracer.py,sha256=zH9DENdJBPEvWkThgwqvHOW7aC1EPC8xha_WpEj-SRs,8243
  sglang/srt/backend_config.py,sha256=7MdHjNsZeAKB9IWWxyrvyOjJJAdI5tl9hWl-MV7yHrI,226
  sglang/srt/hf_transformers_utils.py,sha256=soRyYLoCn7GxgxvonufGFkdFBA3eH5i3Izk_wi7p1l0,5285
@@ -23,14 +22,14 @@ sglang/srt/memory_pool.py,sha256=cN3Lrs9fn0DFmt67_IN4g06mPzKUxpbAJGUw4O33xbo,360
  sglang/srt/model_config.py,sha256=R7YaR8H8AmCJl_1XcSP0zII_5ebZNl0wMXNVANGWd2c,997
  sglang/srt/sampling_params.py,sha256=Sd9l_uIIuS_mhbzljKwTGDO9ESMviNOYGxOifc71RrY,2895
  sglang/srt/server.py,sha256=XxTS1K4N5y-ZknLBQefxk1UxC50l6DABVqJOrJ-NG74,6388
- sglang/srt/server_args.py,sha256=Fpj3To5hEgmWn9qCS-pfypOEh34x9xVmiHBoEx5Smbo,4932
+ sglang/srt/server_args.py,sha256=ojox8nu2tgPEy_JlKKEvRenby4HKkmWk-1MpHy3PmnI,5771
  sglang/srt/utils.py,sha256=-2F99bqYT99x1jScMjciJxgQec6CaH6PcCHSmrKHhhY,5692
  sglang/srt/constrained/fsm.py,sha256=H4kXSsV4IX2ow5TMmnmd-8ho4qqJ5mpVZ4MOH5FUtnY,12900
  sglang/srt/constrained/fsm_cache.py,sha256=KX4bFX5hj0W66SC9pSvst1ew7etaOMTtTC75z0enRME,1087
  sglang/srt/constrained/regex.py,sha256=CcV7KBOKS2ZxGoEr6BHG5okagNIGEXYvGvhKXu5gtDA,18689
  sglang/srt/constrained/tokenizer.py,sha256=rei9yKHFETcbDPOpI7bpIYdrBFgIBhGr_U-zb3r5Beo,7951
- sglang/srt/layers/context_flashattention_nopad.py,sha256=yTQBOo-kKKu5F-YFBo26lCWKtyaGae1M5gsn2GZpfAE,5205
- sglang/srt/layers/extend_attention.py,sha256=nERsTpimVwdF-gHXmjy3D7zbSb4RrbVswmlzuA2NpWA,12559
+ sglang/srt/layers/context_flashattention_nopad.py,sha256=GkjLiTkS4px_uLcW0aDocE3_OBXtujZ-SlsN2b2U7ng,5204
+ sglang/srt/layers/extend_attention.py,sha256=pWVE6ySnPiVLFON__bie73eDhmXHk4tECMK8zTiJNbI,12558
  sglang/srt/layers/get_selected_logprob.py,sha256=CpMXM9WXMSB-AVaxBB_aVl1Qx_ZtAFFnjDTm4CgNDpU,2199
  sglang/srt/layers/logits_processor.py,sha256=rwcXwdZ7-dW9zvJX3MF_EHSxMLbU7TIQ9xUIYRu-WAs,3013
  sglang/srt/layers/radix_attention.py,sha256=hmPNFg2TkN4EAVUj376N_89RRtUYRwFgUpjj5SydnRk,6170
@@ -40,18 +39,18 @@ sglang/srt/managers/io_struct.py,sha256=5jMWj6_U8yTQd5V3tpDtThnoFyF0A3ln-4Z5bSL3
  sglang/srt/managers/openai_protocol.py,sha256=Eid_734Wup4jsL1ZS2Op0vwRuzvNbF4mV2UcwFxqEvI,327
  sglang/srt/managers/tokenizer_manager.py,sha256=jVwr0lM18RFJLhDb5TWlUpQ4Q8tALT4L6GY0jmaZkLw,7861
  sglang/srt/managers/router/infer_batch.py,sha256=UfS1uVhGnM-62Xv1cfu_IoTeIUxkjkKc4W3trtGbadc,11541
- sglang/srt/managers/router/manager.py,sha256=H-T-LlnIssHw-FXMHbs3yDQewkTMBCqG6jTYjugopCA,2527
- sglang/srt/managers/router/model_rpc.py,sha256=G7NvEDguSNj-ZAXBo7GpNQJJHW5WAy_1-qQ7bzqltTU,19286
- sglang/srt/managers/router/model_runner.py,sha256=U-SBnEeLvwolLcaxyxrPgVG7PnR2rRvuXWV50t9y0Fo,16480
+ sglang/srt/managers/router/manager.py,sha256=AVCdYKKYcIQsIwpudkfFY4jh6M--ubLjXeYGzfi2ebw,2528
+ sglang/srt/managers/router/model_rpc.py,sha256=CR3qbHvShttlC19qAZ8B8nhT6UPobeu2Dy3Z0n6WdC8,19448
+ sglang/srt/managers/router/model_runner.py,sha256=IhSdpBcd54HN01HDi_PAkJztFxEGDcnktdoPZDWEx3s,16487
  sglang/srt/managers/router/radix_cache.py,sha256=ZQPm9HhQ7vD3Gl5nhuvw3ZW4ZRARcplqWed1GYUvHCg,6441
  sglang/srt/managers/router/scheduler.py,sha256=ejuIRwqqMZVXFKUionRJxy5AtNvK25YoGRO9rFY-rc8,2926
  sglang/srt/models/llama2.py,sha256=D3j-NtyM8PA74UhXM7wSPogI2HKX-JcQAWcOusrZZo0,11320
  sglang/srt/models/llava.py,sha256=COS0IC6Yo-QiwKe5emgCbtEe9HgaSu5tt6CQA7UtV38,8533
- sglang/srt/models/mixtral.py,sha256=j91xOt6NZ5tJiyTPqmUSzgJqFAw7vTDnfBtEs5x0jDM,13714
- sglang/test/test_programs.py,sha256=ua3wufnS3x6d_U3aboY4ivqoglrRPZj18j96vuiUtiE,11348
+ sglang/srt/models/mixtral.py,sha256=frd2XsNZwP0XsQtRiYhgy4PErLNLgtIsLakmNrOKBAU,13712
+ sglang/test/test_programs.py,sha256=EovA2xL7fODcTbFj2wAAmYKlg1mLZ1x1BRU6nrXFRdE,11416
  sglang/test/test_utils.py,sha256=Knxg3BTA6d_7XSlprbBCdvfDr2SN5x7LhkT-tZFk5EQ,4828
- sglang-0.1.4.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- sglang-0.1.4.dist-info/METADATA,sha256=qYnmtS2k2ddncYRUOCZoE_SXEli5cFD7yh_JkP7IVWk,22676
- sglang-0.1.4.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
- sglang-0.1.4.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
- sglang-0.1.4.dist-info/RECORD,,
+ sglang-0.1.5.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ sglang-0.1.5.dist-info/METADATA,sha256=aepmAL6VoXRcxZBIDKvxwikCYSbvWFm_JFGTxb3Mgfw,23345
+ sglang-0.1.5.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
+ sglang-0.1.5.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
+ sglang-0.1.5.dist-info/RECORD,,
sglang/backend/huggingface.py DELETED
@@ -1,349 +0,0 @@
1
- import functools
2
- from enum import Enum, auto
3
- from typing import Callable, List, Optional, Union
4
-
5
- import numpy as np
6
- import torch
7
- import transformers
8
- from sglang.backend.base_backend import BaseBackend
9
- from sglang.lang.chat_template import get_chat_template_by_model_path
10
- from sglang.lang.interpreter import ProgramState
11
- from sglang.utils import get_available_gpu_memory
12
- from transformers import (
13
- AutoModelForCausalLM,
14
- AutoTokenizer,
15
- StoppingCriteria,
16
- StoppingCriteriaList,
17
- )
18
- from transformersgl.generation.logits_process import (
19
- LogitsProcessorList,
20
- RepetitionPenaltyLogitsProcessor,
21
- TemperatureLogitsWarper,
22
- TopKLogitsWarper,
23
- TopPLogitsWarper,
24
- )
25
-
26
-
27
- class StopReason(Enum):
28
- EOS_TOKEN = auto()
29
- STOP_STR = auto()
30
- LENGTH = auto()
31
-
32
-
33
- def load_model(
34
- model_name: str,
35
- device,
36
- num_gpus,
37
- max_gpu_memory,
38
- model_kwargs=None,
39
- tokenizer_kwargs=None,
40
- ):
41
- model_kwargs = model_kwargs or {}
42
- tokenizer_kwargs = tokenizer_kwargs or {}
43
-
44
- if device == "cuda":
45
- model_kwargs["torch_dtype"] = torch.float16
46
- if num_gpus != 1:
47
- model_kwargs["device_map"] = "auto"
48
- if max_gpu_memory is None:
49
- model_kwargs[
50
- "device_map"
51
- ] = "sequential" # This is important for not the same VRAM sizes
52
- available_gpu_memory = [
53
- get_available_gpu_memory(i, False) for i in range(num_gpus)
54
- ]
55
- model_kwargs["max_memory"] = {
56
- i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
57
- for i in range(num_gpus)
58
- }
59
- else:
60
- model_kwargs["max_memory"] = {
61
- i: max_gpu_memory for i in range(num_gpus)
62
- }
63
- elif device == "cpu":
64
- model_kwargs["torch_dtype"] = torch.float32
65
- else:
66
- raise ValueError(f"Invalid device: {device}")
67
-
68
- model = AutoModelForCausalLM.from_pretrained(
69
- model_name, low_cpu_mem_usage=True, **model_kwargs
70
- )
71
- tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
72
-
73
- if num_gpus == 1:
74
- model.to(device).eval()
75
-
76
- return model, tokenizer
77
-
78
-
79
- def prepare_logits_processor(
80
- temperature: float, repetition_penalty: float, top_p: float, top_k: int
81
- ) -> LogitsProcessorList:
82
- processor_list = LogitsProcessorList()
83
- # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
84
- if temperature >= 1e-5 and temperature != 1.0:
85
- processor_list.append(TemperatureLogitsWarper(temperature))
86
- if repetition_penalty > 1.0:
87
- processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
88
- if 1e-8 <= top_p < 1.0:
89
- processor_list.append(TopPLogitsWarper(top_p))
90
- if top_k > 0:
91
- processor_list.append(TopKLogitsWarper(top_k))
92
- return processor_list
93
-
94
-
95
- @functools.lru_cache
96
- def get_token_healing_mask(tokenizer, prompt_last_token):
97
- last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
98
- disallowed = torch.zeros(len(tokenizer), dtype=bool)
99
- for s, t_id in tokenizer.get_vocab().items():
100
- if not s.startswith(last_str):
101
- disallowed[t_id] = 1
102
- return disallowed
103
-
104
-
105
- @functools.lru_cache
106
- def get_int_token_mask(tokenizer):
107
- disallowed = torch.zeros(len(tokenizer), dtype=bool)
108
- for s, t_id in tokenizer.get_vocab().items():
109
- s = s.replace("▁", "").strip()
110
- if not (s.isdigit() or len(s) == 0 or s == ","):
111
- disallowed[t_id] = 1
112
- disallowed[tokenizer.eos_token_id] = 0
113
- return disallowed
114
-
115
-
116
- @torch.inference_mode()
117
- def generate_stream(
118
- model,
119
- tokenizer,
120
- prompt,
121
- max_new_tokens,
122
- stop: List[str],
123
- temperature,
124
- top_p,
125
- token_healing,
126
- logit_mask=None,
127
- ):
128
- logits_processor = prepare_logits_processor(
129
- temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
130
- )
131
- device = model.device
132
- input_ids = tokenizer.encode(prompt)
133
- output_ids = list(input_ids)
134
- prompt_len = len(prompt)
135
-
136
- # Resolve stop
137
- stop_token_ids = [tokenizer.eos_token_id]
138
-
139
- # Token healing
140
- token_healing = token_healing and len(input_ids) > 0
141
- if token_healing:
142
- token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
143
- del output_ids[-1]
144
-
145
- # Generate
146
- past_key_values = None
147
- stop_reason = None
148
- for i in range(max_new_tokens):
149
- # Forward
150
- if i == 0: # prefill
151
- out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
152
- else: # decoding
153
- out = model(
154
- input_ids=torch.as_tensor([[token]], device=device),
155
- use_cache=True,
156
- past_key_values=past_key_values,
157
- )
158
- logits = out.logits
159
- past_key_values = out.past_key_values
160
-
161
- # Logit mask
162
- if token_healing and i == 0:
163
- logits[0, -1, token_healing_mask] = -1e4
164
- if logit_mask is not None:
165
- logits[0, -1, logit_mask] = -1e4
166
-
167
- # Sample next token
168
- last_token_logits = logits_processor(None, logits[:, -1, :])[0]
169
- if temperature < 1e-5 or top_p < 1e-8: # greedy
170
- token = int(torch.argmax(last_token_logits))
171
- else:
172
- probs = torch.softmax(last_token_logits, dim=-1)
173
- token = int(torch.multinomial(probs, num_samples=1))
174
- output_ids.append(token)
175
-
176
- # Stop condition
177
- if token in stop_token_ids:
178
- stop_reason = StopReason.EOS_TOKEN
179
- break
180
-
181
- output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
182
- for stop_str in stop:
183
- pos = output_str[prompt_len:].find(stop_str)
184
- if pos != -1:
185
- stop_reason = StopReason.STOP_STR
186
- output_str = output_str[: prompt_len + pos]
187
- break
188
-
189
- if stop_reason:
190
- break
191
-
192
- return output_str[prompt_len:]
193
-
194
-
195
- class HuggingFaceTransformers(BaseBackend):
196
- def __init__(
197
- self,
198
- model_name,
199
- device="cuda",
200
- num_gpus=1,
201
- max_gpu_memory=None,
202
- model_kwargs=None,
203
- tokenizer_kwargs=None,
204
- ):
205
- self.model_name = model_name
206
- self.device = device
207
-
208
- self.model, self.tokenizer = load_model(
209
- model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
210
- )
211
-
212
- self.chat_template = get_chat_template_by_model_path(model_name)
213
-
214
- def get_chat_template(self):
215
- return self.chat_template
216
-
217
- def cache_prefix(self, prefix_str: str):
218
- pass
219
-
220
- def uncache_prefix(self, rid: str):
221
- pass
222
-
223
- def end_request(self, rid: str):
224
- pass
225
-
226
- def begin_program(self, s: ProgramState):
227
- pass
228
-
229
- def end_program(self, s: ProgramState):
230
- pass
231
-
232
- def fill(self, s: ProgramState, text: str):
233
- return False
234
-
235
- def generate_internal(
236
- self,
237
- prompt: str,
238
- max_tokens: int,
239
- stop: Union[str, List[str]],
240
- temperature: float,
241
- top_p: float,
242
- dtype: Optional[str] = None,
243
- ):
244
- if dtype is None:
245
- comp = generate_stream(
246
- self.model,
247
- self.tokenizer,
248
- prompt,
249
- max_new_tokens=max_tokens,
250
- stop=stop,
251
- temperature=temperature,
252
- top_p=top_p,
253
- token_healing=True,
254
- )
255
- elif dtype in [str, "str", "string"]:
256
- comp = generate_stream(
257
- self.model,
258
- self.tokenizer,
259
- prompt + '"',
260
- max_new_tokens=max_tokens,
261
- stop=['"'],
262
- temperature=temperature,
263
- top_p=top_p,
264
- token_healing=False,
265
- )
266
- comp = '"' + comp + '"'
267
- elif dtype in [int, "int"]:
268
- logit_mask = get_int_token_mask(self.tokenizer)
269
- comp = generate_stream(
270
- self.model,
271
- self.tokenizer,
272
- prompt,
273
- max_new_tokens=max_tokens,
274
- stop=stop + [" ", ","],
275
- temperature=temperature,
276
- top_p=top_p,
277
- token_healing=False,
278
- logit_mask=logit_mask,
279
- )
280
- return comp
281
-
282
- def generate(
283
- self,
284
- s: ProgramState,
285
- max_tokens: int,
286
- stop: Union[str, List[str]],
287
- temperature: float,
288
- top_p: float,
289
- dtype: Optional[str] = None,
290
- ):
291
- prompt = s.text
292
- comp = self.generate_internal(
293
- prompt, max_tokens, stop, temperature, top_p, dtype
294
- )
295
- return comp
296
-
297
- def parallel_generate(
298
- self,
299
- s: ProgramState,
300
- prefixes: List[str],
301
- join_func: Callable,
302
- max_tokens: int,
303
- stop: Union[str, List[str]],
304
- temperature: float,
305
- top_p: float,
306
- dtype: Optional[str] = None,
307
- ):
308
- prompt = s.text
309
- parallel_prompts = [prompt + prefix for prefix in prefixes]
310
-
311
- comps = []
312
- for i in range(len(parallel_prompts)):
313
- comps.append(
314
- self.generate_internal(
315
- parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
316
- )
317
- )
318
-
319
- joined = join_func([p + c for p, c in zip(prefixes, comps)])
320
- return joined, comps
321
-
322
- @torch.inference_mode()
323
- def select(
324
- self, s: ProgramState, choices: List[str], temperature: float, top_p: float
325
- ):
326
- loss_fct = torch.nn.CrossEntropyLoss()
327
- prompt = s.text
328
-
329
- prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
330
- prompt_choices = [prompt + choice for choice in choices]
331
-
332
- scores = []
333
- for i in range(len(choices)):
334
- choice_ids = self.tokenizer.encode(
335
- prompt_choices[i], return_tensors="pt"
336
- ).to(self.model.device)
337
- logits = self.model(choice_ids).logits
338
-
339
- # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()
340
-
341
- logprobs = torch.log(torch.softmax(logits, dim=-1))
342
- idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
343
- idx2 = choice_ids[0, 1:]
344
- selected_logprobs = logprobs[0, idx1, idx2]
345
- score = selected_logprobs.mean().item()
346
- scores.append(score)
347
-
348
- decision = choices[np.argmax(scores)]
349
- return decision, scores
sglang/backend/tgi.py DELETED
@@ -1,190 +0,0 @@
1
- import re
2
- from concurrent.futures import ThreadPoolExecutor
3
- from functools import partial
4
- from itertools import repeat
5
- from typing import List, Optional, Union
6
-
7
- from sglang.backend.base_backend import BaseBackend
8
- from sglang.lang.chat_template import get_chat_template_by_model_path
9
- from sglang.lang.interpreter import StreamExecutor
10
- from sglang.lang.ir import SglSamplingParams
11
- from sglang.utils import http_request
12
-
13
-
14
- class TGI(BaseBackend):
15
- def __init__(self, base_url):
16
- super().__init__()
17
-
18
- self.base_url = base_url
19
-
20
- res = http_request(self.base_url + "/info")
21
- assert res.status_code == 200
22
- self.model_info = res.json()
23
- self.chat_template = get_chat_template_by_model_path(
24
- self.model_info["model_id"]
25
- )
26
-
27
- def get_model_name(self):
28
- return self.model_info["model_id"]
29
-
30
- def get_chat_template(self):
31
- return self.chat_template
32
-
33
- @staticmethod
34
- def adapt_params(max_tokens, stop, sampling_params, **override_params):
35
- temperature = sampling_params.temperature
36
- do_sample = True
37
- if temperature == 0:
38
- do_sample = False
39
- temperature = None
40
-
41
- if stop is None:
42
- stop = []
43
- elif isinstance(stop, str):
44
- stop = [stop]
45
-
46
- top_p = sampling_params.top_p
47
- if top_p == 0:
48
- top_p = 0.001
49
- if top_p == 1:
50
- top_p = 0.999
51
-
52
- top_k = sampling_params.top_k
53
- if top_k == -1:
54
- top_k = None
55
-
56
- params = {
57
- "decoder_input_details": False,
58
- "details": False,
59
- "do_sample": do_sample,
60
- "max_new_tokens": max_tokens,
61
- "stop": stop,
62
- "temperature": temperature,
63
- "top_p": top_p,
64
- "top_k": top_k,
65
- "return_full_text": False,
66
- }
67
- params.update(override_params)
68
- return params
69
-
70
- @staticmethod
71
- def _extract_int(text):
72
- words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
73
- for word in words:
74
- try:
75
- int(word)
76
- return word
77
- except ValueError:
78
- continue
79
- raise ValueError
80
-
81
- @staticmethod
82
- def _extract_choice(choices, text):
83
- # FIXME: Current only support the case where the choices are single words.
84
- words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
85
- for word in words:
86
- if word in choices:
87
- return word
88
- raise ValueError
89
-
90
- @staticmethod
91
- def _truncate_to_stop(text, stop):
92
- # The stop sequence may not be a single token. In this case TGI will generate
93
- # too many tokens so we need to truncate the output.
94
- if stop:
95
- stop = [stop] if isinstance(stop, str) else stop
96
- for stop_seq in stop:
97
- pos = text.find(stop_seq)
98
- if pos != -1:
99
- return text[:pos]
100
- return text
101
-
102
- def _make_request(self, params):
103
- res = http_request(self.base_url + "/generate", json=params)
104
- if res.status_code != 200:
105
- raise ValueError(f"Error from TGI backend: {res.text}")
106
- return res.json()
107
-
108
- def retry_for_expected(self, prompt, params, extract_fn, retry=5):
109
- # TGI does not support logis_bias (yet), so we have to use an inefficient hack.
110
- failed = []
111
- while retry > 0:
112
- res_json = self._make_request(
113
- {
114
- "inputs": prompt,
115
- "parameters": params,
116
- }
117
- )
118
- text = res_json["generated_text"]
119
- try:
120
- return extract_fn(text)
121
- except ValueError:
122
- retry -= 1
123
- failed.append(text)
124
-
125
- msg = "=" * 20 + "\n"
126
- msg += f"Prompt:\n{prompt}\n"
127
- msg += "=" * 20 + "\n"
128
- for i, text in enumerate(failed):
129
- msg += f"====== Try {i+1}:\n{text}\n"
130
-
131
- raise ValueError(
132
- f"Model {self.model_info['model_id']} served by TGI backend does not generate"
133
- "expected output. Please improve the prompt, increase the temperature, or "
134
- f"use different models.\n{msg}"
135
- )
136
-
137
- def select(
138
- self,
139
- s: StreamExecutor,
140
- choices: List[str],
141
- sampling_params: SglSamplingParams,
142
- ):
143
- decision = self.retry_for_expected(
144
- s.text_,
145
- self.adapt_params(16, [], sampling_params),
146
- partial(self._extract_choice, choices),
147
- )
148
- return decision, [1 if choice == decision else 0 for choice in choices]
149
-
150
- def generate(
151
- self,
152
- s: StreamExecutor,
153
- max_tokens: int,
154
- stop: Union[str, List[str]],
155
- sampling_params: SglSamplingParams,
156
- dtype: Optional[str] = None,
157
- ):
158
- if dtype is None:
159
- res_json = self._make_request(
160
- {
161
- "inputs": s.text_,
162
- "parameters": self.adapt_params(max_tokens, stop, sampling_params),
163
- }
164
- )
165
- return self._truncate_to_stop(res_json["generated_text"], stop), {}
166
-
167
- if dtype in [str, "str", "string"]:
168
- stop = ['"']
169
- res_json = self._make_request(
170
- {
171
- "inputs": f'{s.text_}"',
172
- "parameters": self.adapt_params(max_tokens, stop, sampling_params),
173
- }
174
- )
175
- return (
176
- '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
177
- {},
178
- )
179
-
180
- if dtype in [int, "int"]:
181
- return (
182
- self.retry_for_expected(
183
- s.text_,
184
- self.adapt_params(max_tokens, stop, sampling_params),
185
- self._extract_int,
186
- ),
187
- {},
188
- )
189
-
190
- raise ValueError(f"Unknown dtype: {dtype}")
File without changes