sglang 0.1.12__tar.gz → 0.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {sglang-0.1.12/sglang.egg-info → sglang-0.1.14}/PKG-INFO +16 -6
  2. {sglang-0.1.12 → sglang-0.1.14}/README.md +13 -3
  3. {sglang-0.1.12 → sglang-0.1.14}/pyproject.toml +3 -3
  4. {sglang-0.1.12 → sglang-0.1.14}/sglang/__init__.py +1 -1
  5. {sglang-0.1.12 → sglang-0.1.14}/sglang/api.py +14 -0
  6. {sglang-0.1.12 → sglang-0.1.14}/sglang/backend/anthropic.py +18 -12
  7. {sglang-0.1.12 → sglang-0.1.14}/sglang/backend/base_backend.py +6 -0
  8. {sglang-0.1.12 → sglang-0.1.14}/sglang/backend/openai.py +41 -12
  9. {sglang-0.1.12 → sglang-0.1.14}/sglang/backend/runtime_endpoint.py +57 -6
  10. {sglang-0.1.12 → sglang-0.1.14}/sglang/lang/chat_template.py +47 -26
  11. {sglang-0.1.12 → sglang-0.1.14}/sglang/lang/interpreter.py +15 -2
  12. {sglang-0.1.12 → sglang-0.1.14}/sglang/lang/ir.py +1 -1
  13. sglang-0.1.14/sglang/srt/constrained/__init__.py +38 -0
  14. sglang-0.1.14/sglang/srt/constrained/fsm_cache.py +24 -0
  15. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/layers/context_flashattention_nopad.py +1 -1
  16. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/layers/extend_attention.py +7 -6
  17. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/layers/radix_attention.py +2 -10
  18. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/layers/token_attention.py +12 -4
  19. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/managers/io_struct.py +3 -1
  20. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/managers/router/infer_batch.py +6 -2
  21. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/managers/router/model_rpc.py +45 -32
  22. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/managers/router/model_runner.py +40 -25
  23. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/managers/tokenizer_manager.py +2 -0
  24. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/model_config.py +12 -5
  25. sglang-0.1.14/sglang/srt/models/gemma.py +340 -0
  26. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/models/llama2.py +5 -5
  27. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/models/llava.py +2 -4
  28. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/models/mixtral.py +5 -5
  29. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/models/qwen.py +4 -4
  30. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/models/qwen2.py +5 -5
  31. sglang-0.1.14/sglang/srt/models/stablelm.py +293 -0
  32. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/server.py +111 -47
  33. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/server_args.py +44 -9
  34. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/utils.py +1 -0
  35. {sglang-0.1.12 → sglang-0.1.14}/sglang/test/test_utils.py +1 -1
  36. {sglang-0.1.12 → sglang-0.1.14}/sglang/utils.py +15 -12
  37. {sglang-0.1.12 → sglang-0.1.14/sglang.egg-info}/PKG-INFO +16 -6
  38. {sglang-0.1.12 → sglang-0.1.14}/sglang.egg-info/SOURCES.txt +2 -1
  39. {sglang-0.1.12 → sglang-0.1.14}/sglang.egg-info/requires.txt +2 -2
  40. sglang-0.1.12/sglang/srt/constrained/__init__.py +0 -16
  41. sglang-0.1.12/sglang/srt/constrained/fsm_cache.py +0 -13
  42. sglang-0.1.12/sglang/srt/models/gpt_neox.py +0 -274
  43. {sglang-0.1.12 → sglang-0.1.14}/LICENSE +0 -0
  44. {sglang-0.1.12 → sglang-0.1.14}/setup.cfg +0 -0
  45. {sglang-0.1.12 → sglang-0.1.14}/sglang/backend/__init__.py +0 -0
  46. {sglang-0.1.12 → sglang-0.1.14}/sglang/backend/vertexai.py +0 -0
  47. {sglang-0.1.12 → sglang-0.1.14}/sglang/global_config.py +0 -0
  48. {sglang-0.1.12 → sglang-0.1.14}/sglang/lang/__init__.py +0 -0
  49. {sglang-0.1.12 → sglang-0.1.14}/sglang/lang/compiler.py +0 -0
  50. {sglang-0.1.12 → sglang-0.1.14}/sglang/lang/tracer.py +0 -0
  51. {sglang-0.1.12 → sglang-0.1.14}/sglang/launch_server.py +0 -0
  52. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/backend_config.py +0 -0
  53. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/constrained/base_cache.py +0 -0
  54. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/constrained/jump_forward.py +0 -0
  55. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/conversation.py +0 -0
  56. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/hf_transformers_utils.py +0 -0
  57. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/layers/logits_processor.py +0 -0
  58. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/managers/detokenizer_manager.py +0 -0
  59. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/managers/openai_protocol.py +0 -0
  60. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/managers/router/manager.py +0 -0
  61. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/managers/router/radix_cache.py +0 -0
  62. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/managers/router/scheduler.py +0 -0
  63. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/memory_pool.py +0 -0
  64. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/mm_utils.py +0 -0
  65. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/models/mistral.py +0 -0
  66. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/models/yivl.py +0 -0
  67. {sglang-0.1.12 → sglang-0.1.14}/sglang/srt/sampling_params.py +0 -0
  68. {sglang-0.1.12 → sglang-0.1.14}/sglang/test/test_conversation.py +0 -0
  69. {sglang-0.1.12 → sglang-0.1.14}/sglang/test/test_openai_protocol.py +0 -0
  70. {sglang-0.1.12 → sglang-0.1.14}/sglang/test/test_programs.py +0 -0
  71. {sglang-0.1.12 → sglang-0.1.14}/sglang.egg-info/dependency_links.txt +0 -0
  72. {sglang-0.1.12 → sglang-0.1.14}/sglang.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sglang
- Version: 0.1.12
+ Version: 0.1.14
  Summary: A structured generation langauge for LLMs.
  License: Apache License
  Version 2.0, January 2004
@@ -221,7 +221,7 @@ Requires-Dist: torch; extra == "srt"
  Requires-Dist: uvloop; extra == "srt"
  Requires-Dist: uvicorn; extra == "srt"
  Requires-Dist: zmq; extra == "srt"
- Requires-Dist: vllm>=0.2.5; extra == "srt"
+ Requires-Dist: vllm>=0.3.3; extra == "srt"
  Requires-Dist: interegular; extra == "srt"
  Requires-Dist: lark; extra == "srt"
  Requires-Dist: numba; extra == "srt"
@@ -235,14 +235,19 @@ Provides-Extra: openai
  Requires-Dist: openai>=1.0; extra == "openai"
  Requires-Dist: numpy; extra == "openai"
  Provides-Extra: anthropic
- Requires-Dist: anthropic; extra == "anthropic"
+ Requires-Dist: anthropic>=0.20.0; extra == "anthropic"
  Requires-Dist: numpy; extra == "anthropic"
  Provides-Extra: all
  Requires-Dist: sglang[srt]; extra == "all"
  Requires-Dist: sglang[openai]; extra == "all"
  Requires-Dist: sglang[anthropic]; extra == "all"

- # SGLang
+ <div align="center">
+ <img src="assets/logo.png" alt="logo" width="400"></img>
+ </div>
+
+ --------------------------------------------------------------------------------
+
  | [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |

  SGLang is a structured generation language designed for large language models (LLMs).
@@ -254,7 +259,7 @@ The core features of SGLang include:

  ## News
  - [2024/02] 🔥 SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
- - [2024/01] 🔥 SGLang powers the serving of the offical **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
+ - [2024/01] 🔥 SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
  - [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).

  ## Contents
@@ -496,7 +501,7 @@ def text_qa(s, question):
  s += "Q: " + question + "\n"
  s += "A:" + sgl.gen("answer", stop="\n")

- states = text_qa.run(
+ state = text_qa.run(
  question="What is the capital of France?",
  temperature=0.1,
  stream=True
@@ -608,8 +613,13 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
  - Mistral
  - Mixtral
  - Qwen / Qwen 2
+ - Gemma
+ - Please add a new flag `--attention-reduce-in-fp32` to avoid some precision errors.
+ - `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32`
  - LLaVA
  - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
+ - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
+ - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 3000`
  - Yi-VL
  - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py).
  - AWQ/GPTQ quantization
@@ -1,4 +1,9 @@
- # SGLang
+ <div align="center">
+ <img src="assets/logo.png" alt="logo" width="400"></img>
+ </div>
+
+ --------------------------------------------------------------------------------
+
  | [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |

  SGLang is a structured generation language designed for large language models (LLMs).
@@ -10,7 +15,7 @@ The core features of SGLang include:

  ## News
  - [2024/02] 🔥 SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)).
- - [2024/01] 🔥 SGLang powers the serving of the offical **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
+ - [2024/01] 🔥 SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)).
  - [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)).

  ## Contents
@@ -252,7 +257,7 @@ def text_qa(s, question):
  s += "Q: " + question + "\n"
  s += "A:" + sgl.gen("answer", stop="\n")

- states = text_qa.run(
+ state = text_qa.run(
  question="What is the capital of France?",
  temperature=0.1,
  stream=True
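Note: the variable is renamed from `states` to `state` because `run()` returns a single `ProgramState`. A minimal sketch of consuming that stream, assuming the `text_iter` helper (the synchronous counterpart of the `text_async_iter` method visible in the interpreter diff further below):

```python
import sglang as sgl

@sgl.function
def text_qa(s, question):
    s += "Q: " + question + "\n"
    s += "A:" + sgl.gen("answer", stop="\n")

state = text_qa.run(
    question="What is the capital of France?",
    temperature=0.1,
    stream=True,
)

# Print the text as it is generated.
for out in state.text_iter():
    print(out, end="", flush=True)
```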
@@ -364,8 +369,13 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
  - Mistral
  - Mixtral
  - Qwen / Qwen 2
+ - Gemma
+ - Please add a new flag `--attention-reduce-in-fp32` to avoid some precision errors.
+ - `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32`
  - LLaVA
  - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
+ - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
+ - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 3000`
  - Yi-VL
  - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py).
  - AWQ/GPTQ quantization
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "sglang"
- version = "0.1.12"
+ version = "0.1.14"
  description = "A structured generation langauge for LLMs."
  readme = "README.md"
  requires-python = ">=3.8"
@@ -19,10 +19,10 @@ dependencies = [

  [project.optional-dependencies]
  srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
- "zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
+ "zmq", "vllm>=0.3.3", "interegular", "lark", "numba",
  "pydantic", "referencing", "diskcache", "cloudpickle", "pillow", "outlines>=0.0.27"]
  openai = ["openai>=1.0", "numpy"]
- anthropic = ["anthropic", "numpy"]
+ anthropic = ["anthropic>=0.20.0", "numpy"]
  all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]

  [project.urls]
@@ -1,4 +1,4 @@
- __version__ = "0.1.12"
+ __version__ = "0.1.14"

  from sglang.api import *
  from sglang.global_config import global_config
@@ -44,6 +44,20 @@ def set_default_backend(backend: BaseBackend):
  global_config.default_backend = backend


+ def flush_cache(backend: BaseBackend = None):
+ backend = backend or global_config.default_backend
+ if backend is None:
+ return False
+ return backend.flush_cache()
+
+
+ def get_server_args(backend: BaseBackend = None):
+ backend = backend or global_config.default_backend
+ if backend is None:
+ return None
+ return backend.get_server_args()
+
+
  def gen(
  name: Optional[str] = None,
  max_tokens: Optional[int] = None,
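The two new module-level helpers simply delegate to the current default backend. A minimal usage sketch, assuming a local SRT server is already running (the URL is illustrative):

```python
import sglang as sgl

# Point the default backend at a running server (illustrative URL).
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

# New in 0.1.14: query the server's launch arguments and flush its cache.
print(sgl.get_server_args())  # server args dict, or None when no default backend is set
ok = sgl.flush_cache()        # with a RuntimeEndpoint backend, True on HTTP 200; False when no backend is set
```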
@@ -30,13 +30,17 @@ class Anthropic(BaseBackend):
  s: StreamExecutor,
  sampling_params: SglSamplingParams,
  ):
- prompt = s.text_
- ret = anthropic.Anthropic().completions.create(
+ if s.messages_:
+ messages = s.messages_
+ else:
+ messages = [{"role": "user", "content": s.text_}]
+
+ ret = anthropic.Anthropic().messages.create(
  model=self.model_name,
- prompt=prompt,
+ messages=messages,
  **sampling_params.to_anthropic_kwargs(),
  )
- comp = ret.completion
+ comp = ret.content[0].text

  return comp, {}

@@ -45,13 +49,15 @@ class Anthropic(BaseBackend):
  s: StreamExecutor,
  sampling_params: SglSamplingParams,
  ):
- prompt = s.text_
- generator = anthropic.Anthropic().completions.create(
+ if s.messages_:
+ messages = s.messages_
+ else:
+ messages = [{"role": "user", "content": s.text_}]
+
+ with anthropic.Anthropic().messages.stream(
  model=self.model_name,
- prompt=prompt,
- stream=True,
+ messages=messages,
  **sampling_params.to_anthropic_kwargs(),
- )
-
- for ret in generator:
- yield ret.completion, {}
+ ) as stream:
+ for text in stream.text_stream:
+ yield text, {}
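The backend now targets the Anthropic Messages API (matching the `anthropic>=0.20.0` bump in `pyproject.toml`): chat history in `s.messages_` is forwarded as-is, and a bare prompt is wrapped as a single user message. A minimal sketch of driving it from SGLang; the model name is illustrative:

```python
import sglang as sgl

# Model name is an example; any Messages-API model accepted by the
# anthropic>=0.20.0 SDK should work here.
sgl.set_default_backend(sgl.Anthropic("claude-3-opus-20240229"))

@sgl.function
def haiku(s):
    s += sgl.user("Write a haiku about code review.")
    s += sgl.assistant(sgl.gen("poem", max_tokens=64))

state = haiku.run()
print(state["poem"])
```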
@@ -72,3 +72,9 @@ class BaseBackend:

  def shutdown(self):
  pass
+
+ def flush_cache(self):
+ pass
+
+ def get_server_args(self):
+ pass
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union

  import numpy as np
  from sglang.backend.base_backend import BaseBackend
- from sglang.lang.chat_template import get_chat_template
+ from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
  from sglang.lang.interpreter import StreamExecutor
  from sglang.lang.ir import SglSamplingParams

@@ -41,23 +41,45 @@ INSTRUCT_MODEL_NAMES = [


  class OpenAI(BaseBackend):
- def __init__(self, model_name, *args, **kwargs):
+ def __init__(
+ self,
+ model_name: str,
+ is_chat_model: Optional[bool] = None,
+ chat_template: Optional[ChatTemplate] = None,
+ is_azure: bool = False,
+ *args,
+ **kwargs,
+ ):
  super().__init__()

  if isinstance(openai, Exception):
  raise openai

- self.client = openai.OpenAI(*args, **kwargs)
+ if is_azure:
+ self.client = openai.AzureOpenAI(*args, **kwargs)
+ else:
+ self.client = openai.OpenAI(*args, **kwargs)
+
  self.model_name = model_name
- self.tokenizer = tiktoken.encoding_for_model(model_name)
+ try:
+ self.tokenizer = tiktoken.encoding_for_model(model_name)
+ except KeyError:
+ self.tokenizer = tiktoken.get_encoding("cl100k_base")
  self.logit_bias_int = create_logit_bias_int(self.tokenizer)

- if model_name in INSTRUCT_MODEL_NAMES:
- self.is_chat_model = False
+ self.chat_template = chat_template or get_chat_template_by_model_path(
+ model_name
+ )
+
+ if is_chat_model is not None:
+ self.is_chat_model = is_chat_model
  else:
- self.is_chat_model = True
+ if model_name in INSTRUCT_MODEL_NAMES:
+ self.is_chat_model = False
+ else:
+ self.is_chat_model = True

- self.chat_template = get_chat_template("default")
+ self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]

  def get_chat_template(self):
  return self.chat_template
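The constructor no longer guesses everything from the model name: chat behaviour, chat template, and the Azure client can all be set explicitly, and unknown model names fall back to the `cl100k_base` tokenizer. A hedged sketch (model names, endpoint, key, and API version are placeholders):

```python
from sglang.backend.openai import OpenAI

# Fine-tuned/custom model name: tiktoken does not recognize it, so the
# tokenizer falls back to cl100k_base, and chat behaviour is stated explicitly.
backend = OpenAI("ft:gpt-3.5-turbo:my-org::abc123", is_chat_model=True)

# Azure OpenAI: remaining kwargs are forwarded to openai.AzureOpenAI.
azure_backend = OpenAI(
    "gpt-4",
    is_azure=True,
    azure_endpoint="https://example.openai.azure.com",
    api_key="YOUR_KEY",
    api_version="2024-02-01",
)
```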
@@ -69,7 +91,7 @@ class OpenAI(BaseBackend):
  ):
  if sampling_params.dtype is None:
  if self.is_chat_model:
- if not s.text_.endswith("ASSISTANT:"):
+ if not s.text_.endswith(self.chat_begin_str):
  raise RuntimeError(
  "This use case is not supported. "
  "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
@@ -122,7 +144,11 @@ class OpenAI(BaseBackend):
  ):
  if sampling_params.dtype is None:
  if self.is_chat_model:
- assert s.text_.endswith("ASSISTANT:")
+ if not s.text_.endswith(self.chat_begin_str):
+ raise RuntimeError(
+ "This use case is not supported. "
+ "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
+ )
  prompt = s.messages_
  else:
  prompt = s.text_
@@ -137,7 +163,7 @@ class OpenAI(BaseBackend):
  )
  return generator
  else:
- raise ValueError(f"Unknown dtype: {dtype}")
+ raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

  def select(
  self,
@@ -241,7 +267,10 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwa
  messages=prompt, stream=True, **kwargs
  )
  for ret in generator:
- content = ret.choices[0].delta.content
+ try:
+ content = ret.choices[0].delta.content
+ except IndexError:
+ content = None
  yield content or "", {}
  else:
  generator = client.completions.create(
@@ -12,15 +12,26 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request


  class RuntimeEndpoint(BaseBackend):
- def __init__(self, base_url, auth_token=None):
+ def __init__(
+ self,
+ base_url: str,
+ auth_token: Optional[str] = None,
+ api_key: Optional[str] = None,
+ verify: Optional[str] = None,
+ ):
  super().__init__()
  self.support_concate_and_append = True

  self.base_url = base_url
  self.auth_token = auth_token
+ self.api_key = api_key
+ self.verify = verify

  res = http_request(
- self.base_url + "/get_model_info", auth_token=self.auth_token
+ self.base_url + "/get_model_info",
+ auth_token=self.auth_token,
+ api_key=self.api_key,
+ verify=self.verify,
  )
  assert res.status_code == 200
  self.model_info = res.json()
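`RuntimeEndpoint` can now send an API key and a custom TLS setting with every request; both are threaded through the `http_request` calls in the hunks below. A minimal sketch (URL, key, and CA path are placeholders):

```python
from sglang.backend.runtime_endpoint import RuntimeEndpoint

backend = RuntimeEndpoint(
    "https://my-sglang-host:30000",      # placeholder URL
    api_key="YOUR_API_KEY",              # sent with every request
    verify="/etc/ssl/certs/my-ca.pem",   # forwarded to the HTTP layer, e.g. a CA bundle path
)
print(backend.get_server_args())  # new helper, added in the next hunk
backend.flush_cache()             # new helper, returns True on HTTP 200
```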
@@ -32,6 +43,22 @@ class RuntimeEndpoint(BaseBackend):
  def get_model_name(self):
  return self.model_info["model_path"]

+ def flush_cache(self):
+ res = http_request(
+ self.base_url + "/flush_cache",
+ auth_token=self.auth_token,
+ verify=self.verify,
+ )
+ return res.status_code == 200
+
+ def get_server_args(self):
+ res = http_request(
+ self.base_url + "/get_server_args",
+ auth_token=self.auth_token,
+ verify=self.verify,
+ )
+ return res.json()
+
  def get_chat_template(self):
  return self.chat_template

@@ -40,6 +67,8 @@ class RuntimeEndpoint(BaseBackend):
  self.base_url + "/generate",
  json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
  auth_token=self.auth_token,
+ api_key=self.api_key,
+ verify=self.verify,
  )
  assert res.status_code == 200

@@ -48,6 +77,8 @@ class RuntimeEndpoint(BaseBackend):
  self.base_url + "/generate",
  json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
  auth_token=self.auth_token,
+ api_key=self.api_key,
+ verify=self.verify,
  )
  assert res.status_code == 200

@@ -55,7 +86,11 @@ class RuntimeEndpoint(BaseBackend):
  data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
  self._add_images(s, data)
  res = http_request(
- self.base_url + "/generate", json=data, auth_token=self.auth_token
+ self.base_url + "/generate",
+ json=data,
+ auth_token=self.auth_token,
+ api_key=self.api_key,
+ verify=self.verify,
  )
  assert res.status_code == 200

@@ -87,7 +122,11 @@ class RuntimeEndpoint(BaseBackend):
  self._add_images(s, data)

  res = http_request(
- self.base_url + "/generate", json=data, auth_token=self.auth_token
+ self.base_url + "/generate",
+ json=data,
+ auth_token=self.auth_token,
+ api_key=self.api_key,
+ verify=self.verify,
  )
  obj = res.json()
  comp = obj["text"]
@@ -126,6 +165,8 @@ class RuntimeEndpoint(BaseBackend):
  json=data,
  stream=True,
  auth_token=self.auth_token,
+ api_key=self.api_key,
+ verify=self.verify,
  )
  pos = 0

@@ -157,7 +198,11 @@ class RuntimeEndpoint(BaseBackend):
  data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
  self._add_images(s, data)
  res = http_request(
- self.base_url + "/generate", json=data, auth_token=self.auth_token
+ self.base_url + "/generate",
+ json=data,
+ auth_token=self.auth_token,
+ api_key=self.api_key,
+ verify=self.verify,
  )
  assert res.status_code == 200
  prompt_len = res.json()["meta_info"]["prompt_tokens"]
@@ -171,7 +216,11 @@ class RuntimeEndpoint(BaseBackend):
  }
  self._add_images(s, data)
  res = http_request(
- self.base_url + "/generate", json=data, auth_token=self.auth_token
+ self.base_url + "/generate",
+ json=data,
+ auth_token=self.auth_token,
+ api_key=self.api_key,
+ verify=self.verify,
  )
  assert res.status_code == 200
  obj = res.json()
@@ -188,6 +237,8 @@ class RuntimeEndpoint(BaseBackend):
  self.base_url + "/concate_and_append_request",
  json={"src_rids": src_rids, "dst_rid": dst_rid},
  auth_token=self.auth_token,
+ api_key=self.api_key,
+ verify=self.verify,
  )
  assert res.status_code == 200

@@ -12,42 +12,43 @@ class ChatTemplateStyle(Enum):
  class ChatTemplate:
  name: str
  default_system_prompt: str
- role_prefix_and_suffix: Dict[str, Tuple[str]]
+ role_prefix_and_suffix: Dict[str, Tuple[str, str]]
  stop_str: List[str] = ()
  image_token: str = "<image>"
  style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

- def get_prefix_and_suffix(self, role, hist_messages):
- if self.style == ChatTemplateStyle.PLAIN:
- return self.role_prefix_and_suffix[role]
- elif self.style == ChatTemplateStyle.LLAMA2:
- if len(hist_messages) == 0 and role == "system":
- return (
- self.role_prefix_and_suffix["user"][0]
- + self.role_prefix_and_suffix["system"][0],
- self.role_prefix_and_suffix["system"][1],
+ def get_prefix_and_suffix(
+ self, role: str, hist_messages: List[Dict]
+ ) -> Tuple[str, str]:
+ prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
+
+ if self.style == ChatTemplateStyle.LLAMA2:
+ if role == "system" and not hist_messages:
+ user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
+ system_prefix, system_suffix = self.role_prefix_and_suffix.get(
+ "system", ("", "")
  )
+ return (user_prefix + system_prefix, system_suffix)
  elif (
- len(hist_messages) == 1
- and role == "user"
+ role == "user"
+ and len(hist_messages) == 1
  and hist_messages[0]["content"] is not None
  ):
- return ("", self.role_prefix_and_suffix["user"][1])
- return self.role_prefix_and_suffix[role]
- else:
- raise ValueError(f"Invalid style: {self.style}")
+ return ("", suffix)
+
+ return prefix, suffix

- def get_prompt(self, messages):
+ def get_prompt(self, messages: List[Dict]) -> str:
  prompt = ""
- for i in range(len(messages)):
- role, content = messages[i]["role"], messages[i]["content"]
+ for i, message in enumerate(messages):
+ role, content = message["role"], message["content"]
  if role == "system" and content is None:
  content = self.default_system_prompt
  if content is None:
  continue

  prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
- prompt += prefix + content + suffix
+ prompt += f"{prefix}{content}{suffix}"
  return prompt


@@ -106,9 +107,9 @@ register_chat_template(
  name="chatml",
  default_system_prompt=None,
  role_prefix_and_suffix={
- "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
- "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
- "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
+ "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+ "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+ "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
  },
  style=ChatTemplateStyle.PLAIN,
  stop_str=("<|im_end|>",),
@@ -121,9 +122,9 @@ register_chat_template(
  name="chatml-llava",
  default_system_prompt="Answer the questions.",
  role_prefix_and_suffix={
- "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
- "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
- "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
+ "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+ "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+ "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
  },
  style=ChatTemplateStyle.PLAIN,
  stop_str=("<|im_end|>",),
@@ -178,6 +179,19 @@ register_chat_template(
  )
  )

+ register_chat_template(
+ ChatTemplate(
+ name="gemma-it",
+ default_system_prompt=None,
+ role_prefix_and_suffix={
+ "system": ("", ""),
+ "user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
+ "assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
+ },
+ style=ChatTemplateStyle.PLAIN,
+ )
+ )
+

  @register_chat_template_matching_function
  def match_vicuna(model_path: str):
@@ -218,6 +232,13 @@ def match_chat_yi(model_path: str):
  return get_chat_template("yi")


+ @register_chat_template_matching_function
+ def match_gemma_it(model_path: str):
+ model_path = model_path.lower()
+ if "gemma" in model_path and "it" in model_path:
+ return get_chat_template("gemma-it")
+
+
  if __name__ == "__main__":
  messages = [
  {"role": "system", "content": None}, # None means default
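With the template registered above and matched for any model path containing both "gemma" and "it", a conversation for `google/gemma-7b-it` renders as below (hand-traced from the prefixes and suffixes in this diff; the system role maps to empty strings):

```python
from sglang.lang.chat_template import get_chat_template

t = get_chat_template("gemma-it")
print(t.get_prompt([
    {"role": "user", "content": "What is SGLang?"},
    {"role": "assistant", "content": "A structured generation language."},
]))
# <start_of_turn>user
# What is SGLang?<end_of_turn>
# <start_of_turn>model
# A structured generation language.<end_of_turn>
```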
@@ -245,6 +245,9 @@ class StreamExecutor:
  self.variable_event[name].wait()
  return self.variables[name]

+ def set_var(self, name, value):
+ self.variables[name] = value
+
  def get_meta_info(self, name):
  if name in self.variable_event:
  self.variable_event[name].wait()
@@ -583,6 +586,10 @@ class StreamExecutor:
  if self.chat_template.stop_str:
  if not clone:
  clone = self.default_sampling_para.clone()
+ if clone.stop == ():
+ clone.stop = []
+ elif isinstance(clone.stop, str):
+ clone.stop = [clone.stop]
  clone.stop += self.chat_template.stop_str

  return clone or self.default_sampling_para
@@ -679,7 +686,7 @@ class ProgramState:
  if var_name is None:
  yield self.text()
  else:
- yield self.get_var(name)
+ yield self.get_var(var_name)

  async def text_async_iter(
  self, var_name: Optional[str] = None, return_meta_data: bool = False
@@ -717,11 +724,14 @@ class ProgramState:
  if var_name is None:
  yield self.text()
  else:
- yield self.get_var(name)
+ yield self.get_var(var_name)

  def get_var(self, name):
  return self.stream_executor.get_var(name)

+ def set_var(self, name, value):
+ return self.stream_executor.set_var(name, value)
+
  def get_meta_info(self, name):
  return self.stream_executor.get_meta_info(name)

@@ -732,6 +742,9 @@ class ProgramState:
  def __getitem__(self, name):
  return self.get_var(name)

+ def __setitem__(self, name, value):
+ self.set_var(name, value)
+
  def __del__(self):
  self.stream_executor.end()
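Together with the existing `__getitem__`, the new setter makes a `ProgramState` readable and writable like a small dict, which is convenient for post-processing generated variables. A minimal sketch, reusing the `text_qa` function from the README diff above:

```python
state = text_qa.run(question="What is the capital of France?")

answer = state["answer"]          # __getitem__ -> get_var (existing)
state["answer"] = answer.strip()  # __setitem__ -> set_var (new in 0.1.14)
```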

@@ -73,7 +73,7 @@ class SglSamplingParams:
  "Regular expression is not supported in the Anthropic backend."
  )
  return {
- "max_tokens_to_sample": self.max_new_tokens,
+ "max_tokens": self.max_new_tokens,
  "stop_sequences": (
  self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
  ),