sglang 0.1.1__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {sglang-0.1.1 → sglang-0.1.3}/PKG-INFO +86 -4
  2. {sglang-0.1.1 → sglang-0.1.3}/README.md +85 -3
  3. {sglang-0.1.1 → sglang-0.1.3}/pyproject.toml +1 -1
  4. {sglang-0.1.1 → sglang-0.1.3}/sglang/__init__.py +1 -1
  5. {sglang-0.1.1 → sglang-0.1.3}/sglang/api.py +13 -1
  6. {sglang-0.1.1 → sglang-0.1.3}/sglang/backend/anthropic.py +3 -3
  7. {sglang-0.1.1 → sglang-0.1.3}/sglang/backend/base_backend.py +3 -3
  8. {sglang-0.1.1 → sglang-0.1.3}/sglang/backend/openai.py +3 -3
  9. {sglang-0.1.1 → sglang-0.1.3}/sglang/backend/runtime_endpoint.py +3 -3
  10. {sglang-0.1.1 → sglang-0.1.3}/sglang/backend/tgi.py +3 -3
  11. {sglang-0.1.1 → sglang-0.1.3}/sglang/lang/compiler.py +3 -7
  12. {sglang-0.1.1 → sglang-0.1.3}/sglang/lang/interpreter.py +5 -7
  13. {sglang-0.1.1 → sglang-0.1.3}/sglang/lang/ir.py +13 -9
  14. {sglang-0.1.1 → sglang-0.1.3}/sglang/lang/tracer.py +2 -1
  15. sglang-0.1.3/sglang/srt/backend_config.py +12 -0
  16. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/managers/router/manager.py +10 -2
  17. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/managers/router/model_rpc.py +15 -2
  18. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/managers/router/radix_cache.py +2 -2
  19. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/managers/router/scheduler.py +1 -1
  20. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/models/mixtral.py +1 -1
  21. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/server_args.py +8 -2
  22. {sglang-0.1.1 → sglang-0.1.3}/sglang/test/test_programs.py +1 -1
  23. {sglang-0.1.1 → sglang-0.1.3}/sglang/test/test_utils.py +22 -2
  24. {sglang-0.1.1 → sglang-0.1.3}/sglang/utils.py +1 -1
  25. {sglang-0.1.1 → sglang-0.1.3}/sglang.egg-info/PKG-INFO +86 -4
  26. {sglang-0.1.1 → sglang-0.1.3}/sglang.egg-info/SOURCES.txt +1 -0
  27. {sglang-0.1.1 → sglang-0.1.3}/LICENSE +0 -0
  28. {sglang-0.1.1 → sglang-0.1.3}/setup.cfg +0 -0
  29. {sglang-0.1.1 → sglang-0.1.3}/sglang/backend/__init__.py +0 -0
  30. {sglang-0.1.1 → sglang-0.1.3}/sglang/backend/huggingface.py +0 -0
  31. {sglang-0.1.1 → sglang-0.1.3}/sglang/flush_cache.py +0 -0
  32. {sglang-0.1.1 → sglang-0.1.3}/sglang/global_config.py +0 -0
  33. {sglang-0.1.1 → sglang-0.1.3}/sglang/lang/__init__.py +0 -0
  34. {sglang-0.1.1 → sglang-0.1.3}/sglang/lang/chat_template.py +0 -0
  35. {sglang-0.1.1 → sglang-0.1.3}/sglang/launch_server.py +0 -0
  36. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/constrained/fsm.py +0 -0
  37. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/constrained/fsm_cache.py +0 -0
  38. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/constrained/regex.py +0 -0
  39. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/constrained/tokenizer.py +0 -0
  40. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/hf_transformers_utils.py +0 -0
  41. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/layers/context_flashattention_nopad.py +0 -0
  42. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/layers/extend_attention.py +0 -0
  43. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/layers/get_selected_logprob.py +0 -0
  44. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/layers/logits_processor.py +0 -0
  45. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/layers/radix_attention.py +0 -0
  46. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/layers/token_attention.py +0 -0
  47. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/managers/detokenizer_manager.py +0 -0
  48. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/managers/io_struct.py +0 -0
  49. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/managers/openai_protocol.py +0 -0
  50. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/managers/router/infer_batch.py +0 -0
  51. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/managers/router/model_runner.py +0 -0
  52. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/managers/tokenizer_manager.py +0 -0
  53. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/memory_pool.py +0 -0
  54. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/model_config.py +0 -0
  55. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/models/llama2.py +0 -0
  56. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/models/llava.py +0 -0
  57. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/sampling_params.py +2 -2
  58. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/server.py +0 -0
  59. {sglang-0.1.1 → sglang-0.1.3}/sglang/srt/utils.py +0 -0
  60. {sglang-0.1.1 → sglang-0.1.3}/sglang.egg-info/dependency_links.txt +0 -0
  61. {sglang-0.1.1 → sglang-0.1.3}/sglang.egg-info/requires.txt +0 -0
  62. {sglang-0.1.1 → sglang-0.1.3}/sglang.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.1.1
+Version: 0.1.3
 Summary: A structured generation language for LLMs.
 License: Apache License
         Version 2.0, January 2004
@@ -329,25 +329,99 @@ You can find more examples at [examples/quick_start](examples/quick_start).
 
 ## Frontend: Structured Generation Language (SGLang)
 
+To begin with, import sglang.
+```python
+import sglang as sgl
+```
+
+`sglang` provides some simple primitives such as `gen`, `select`, and `fork`.
+You can implement your prompt flow in a function decorated by `sgl.function`.
+You can then invoke the function with `run` or `run_batch`.
+The system will manage the state, chat template, and parallelism for you.
+
 ### Control Flow
+```python
+@sgl.function
+def control_flow(s, question):
+    s += "To answer this question: " + question + ", "
+    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "
+
+    # You can use if statements or nested function calls
+    if s["tool"] == "calculator":
+        s += "The math expression is" + sgl.gen("expression")
+    elif s["tool"] == "web browser":
+        s += "The website url is" + sgl.gen("url")
+```
 
 ### Parallelism
+```python
+@sgl.function
+def tip_suggestion(s):
+    s += (
+        "Here are two tips for staying healthy: "
+        "1. Balanced Diet. 2. Regular Exercise.\n\n"
+    )
+
+    forks = s.fork(2)  # Launch parallel prompts
+    for i, f in enumerate(forks):
+        f += f"Now, expand tip {i+1} into a paragraph:\n"
+        f += sgl.gen("detailed_tip", max_tokens=256, stop="\n\n")
+
+    s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
+    s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
+    s += "In summary" + sgl.gen("summary")
+```
 
 ### Multi Modality
 ```python
 @sgl.function
 def image_qa(s, image_file, question):
     s += sgl.user(sgl.image(image_file) + question)
-    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
+    s += sgl.assistant(sgl.gen("answer", max_tokens=256))
 ```
 
-### Constrained decoding
+### Constrained Decoding
+```python
+@sgl.function
+def regular_expression_gen(s):
+    s += "Q: What is the IP address of the Google DNS servers?\n"
+    s += "A: " + sgl.gen(
+        "answer",
+        temperature=0,
+        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
+    )
+```
 
 ### Batching
+```python
+@sgl.function
+def text_qa(s, question):
+    s += "Q: " + question + "\n"
+    s += "A:" + sgl.gen("answer", stop="\n")
+
+states = text_qa.run_batch(
+    [
+        {"question": "What is the capital of the United Kingdom?"},
+        {"question": "What is the capital of France?"},
+        {"question": "What is the capital of Japan?"},
+    ],
+)
+```
 
 ### Streaming
+```python
+@sgl.function
+def text_qa(s, question):
+    s += "Q: " + question + "\n"
+    s += "A:" + sgl.gen("answer", stop="\n")
+
+state = text_qa.run(
+    question="What is the capital of France?",
+    temperature=0.1, stream=True)
 
-### Other Backends
+for out in state.text_iter():
+    print(out, end="", flush=True)
+```
 
 ## Backend: SGLang Runtime (SRT)
 The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
@@ -386,6 +460,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 
 ## Benchmark And Performance
 
+- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
+![llama_7b](assets/llama_7b.jpg)
+
+- Mixtral-8x7B on NVIDIA A10G, FP16, Tensor Parallelism=8
+![mixtral_8x7b](assets/mixtral_8x7b.jpg)
+
+Learn more [here]().
+
 ## Roadmap
 - [ ] Function call
 - [ ] Quantization
@@ -94,25 +94,99 @@ You can find more examples at [examples/quick_start](examples/quick_start).
 
 ## Frontend: Structured Generation Language (SGLang)
 
+To begin with, import sglang.
+```python
+import sglang as sgl
+```
+
+`sglang` provides some simple primitives such as `gen`, `select`, and `fork`.
+You can implement your prompt flow in a function decorated by `sgl.function`.
+You can then invoke the function with `run` or `run_batch`.
+The system will manage the state, chat template, and parallelism for you.
+
 ### Control Flow
+```python
+@sgl.function
+def control_flow(s, question):
+    s += "To answer this question: " + question + ", "
+    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "
+
+    # You can use if statements or nested function calls
+    if s["tool"] == "calculator":
+        s += "The math expression is" + sgl.gen("expression")
+    elif s["tool"] == "web browser":
+        s += "The website url is" + sgl.gen("url")
+```
 
 ### Parallelism
+```python
+@sgl.function
+def tip_suggestion(s):
+    s += (
+        "Here are two tips for staying healthy: "
+        "1. Balanced Diet. 2. Regular Exercise.\n\n"
+    )
+
+    forks = s.fork(2)  # Launch parallel prompts
+    for i, f in enumerate(forks):
+        f += f"Now, expand tip {i+1} into a paragraph:\n"
+        f += sgl.gen("detailed_tip", max_tokens=256, stop="\n\n")
+
+    s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
+    s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
+    s += "In summary" + sgl.gen("summary")
+```
 
 ### Multi Modality
 ```python
 @sgl.function
 def image_qa(s, image_file, question):
     s += sgl.user(sgl.image(image_file) + question)
-    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
+    s += sgl.assistant(sgl.gen("answer", max_tokens=256))
 ```
 
-### Constrained decoding
+### Constrained Decoding
+```python
+@sgl.function
+def regular_expression_gen(s):
+    s += "Q: What is the IP address of the Google DNS servers?\n"
+    s += "A: " + sgl.gen(
+        "answer",
+        temperature=0,
+        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
+    )
+```
 
 ### Batching
+```python
+@sgl.function
+def text_qa(s, question):
+    s += "Q: " + question + "\n"
+    s += "A:" + sgl.gen("answer", stop="\n")
+
+states = text_qa.run_batch(
+    [
+        {"question": "What is the capital of the United Kingdom?"},
+        {"question": "What is the capital of France?"},
+        {"question": "What is the capital of Japan?"},
+    ],
+)
+```
 
 ### Streaming
+```python
+@sgl.function
+def text_qa(s, question):
+    s += "Q: " + question + "\n"
+    s += "A:" + sgl.gen("answer", stop="\n")
+
+state = text_qa.run(
+    question="What is the capital of France?",
+    temperature=0.1, stream=True)
 
-### Other Backends
+for out in state.text_iter():
+    print(out, end="", flush=True)
+```
 
 ## Backend: SGLang Runtime (SRT)
 The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
@@ -151,6 +225,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 
 ## Benchmark And Performance
 
+- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
+![llama_7b](assets/llama_7b.jpg)
+
+- Mixtral-8x7B on NVIDIA A10G, FP16, Tensor Parallelism=8
+![mixtral_8x7b](assets/mixtral_8x7b.jpg)
+
+Learn more [here]().
+
 ## Roadmap
 - [ ] Function call
 - [ ] Quantization
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sglang"
-version = "0.1.1"
+version = "0.1.3"
 description = "A structured generation language for LLMs."
 readme = "README.md"
 requires-python = ">=3.8"
@@ -1,4 +1,4 @@
-__version__ == "0.1.1"
+__version__ = "0.1.3"
 
 from sglang.api import *
 from sglang.global_config import global_config
@@ -17,13 +17,19 @@ from sglang.lang.ir import (
     SglRoleEnd,
     SglSelect,
 )
-from sglang.srt.server import Runtime
 
 
 def function(func: Callable):
     return SglFunction(func)
 
 
+def Runtime(*args, **kwargs):
+    # Avoid importing unnecessary dependency
+    from sglang.srt.server import Runtime
+
+    return Runtime(*args, **kwargs)
+
+
 def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend
 
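With this change, `import sglang` no longer pulls in the heavy `sglang.srt.server` stack; the wrapper defers that import until a runtime is actually constructed. A minimal usage sketch, assuming the quick-start pattern from the README above (the model path is illustrative):

```python
import sglang as sgl  # the srt server stack is no longer imported here

# sglang.srt.server is imported only inside this call.
runtime = sgl.Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
sgl.set_default_backend(runtime)
```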
@@ -37,6 +43,7 @@ def gen(
     top_k: Optional[int] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
+    ignore_eos: Optional[bool] = None,
     dtype: Optional[type] = None,
     choices: Optional[List[str]] = None,
     regex: Optional[str] = None,
@@ -60,6 +67,7 @@
         top_k,
         frequency_penalty,
         presence_penalty,
+        ignore_eos,
         dtype,
         regex,
     )
@@ -74,6 +82,7 @@ def gen_int(
     top_k: Optional[int] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
+    ignore_eos: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -84,6 +93,7 @@
         top_k,
         frequency_penalty,
         presence_penalty,
+        ignore_eos,
         int,
         None,
     )
@@ -98,6 +108,7 @@ def gen_string(
     top_k: Optional[int] = None,
     frequency_penalty: Optional[float] = None,
     presence_penalty: Optional[float] = None,
+    ignore_eos: Optional[bool] = None,
 ):
     return SglGen(
         name,
@@ -108,6 +119,7 @@
         top_k,
         frequency_penalty,
         presence_penalty,
+        ignore_eos,
         str,
         None,
     )
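The new optional `ignore_eos` flag threads through `gen`, `gen_int`, and `gen_string` into `SglSamplingParams` (see the `sglang/lang/ir.py` hunks below), so a generation can keep sampling past the model's EOS token. A hedged sketch of the intended use; the prompt and length are illustrative:

```python
import sglang as sgl

@sgl.function
def fixed_length_gen(s):
    s += "Write the digits of pi:\n"
    # Request up to 128 new tokens and keep going even if the model emits EOS.
    s += sgl.gen("digits", max_tokens=128, ignore_eos=True)
```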
@@ -4,7 +4,7 @@ import numpy as np
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams
 
 try:
     import anthropic
@@ -28,7 +28,7 @@ class Anthropic(BaseBackend):
     def generate(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         prompt = s.text_
         ret = anthropic.Anthropic().completions.create(
@@ -43,7 +43,7 @@ class Anthropic(BaseBackend):
     def generate_stream(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         prompt = s.text_
         generator = anthropic.Anthropic().completions.create(
@@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Union
 
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams
 
 
 class BaseBackend:
@@ -48,14 +48,14 @@ class BaseBackend:
     def generate(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         raise NotImplementedError()
 
     def generate_stream(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         raise NotImplementedError()
 
@@ -4,7 +4,7 @@ import numpy as np
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams
 
 try:
     import openai
@@ -73,7 +73,7 @@ class OpenAI(BaseBackend):
     def generate(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
@@ -122,7 +122,7 @@
     def generate_stream(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         if sampling_params.dtype is None:
             if self.is_chat_model:
@@ -7,7 +7,7 @@ from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams, SglArgument
+from sglang.lang.ir import SglArgument, SglSamplingParams
 from sglang.utils import encode_image_base64, find_printable_text, http_request
 
 
@@ -55,7 +55,7 @@ class RuntimeEndpoint(BaseBackend):
     def generate(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         if sampling_params.dtype is None:
             data = {
@@ -87,7 +87,7 @@ class RuntimeEndpoint(BaseBackend):
     def generate_stream(
         self,
         s: StreamExecutor,
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         if sampling_params.dtype is None:
             data = {
@@ -7,7 +7,7 @@ from typing import List, Optional, Union
 
 from sglang.backend.base_backend import BaseBackend
 from sglang.lang.chat_template import get_chat_template_by_model_path
 from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SamplingParams
+from sglang.lang.ir import SglSamplingParams
 from sglang.utils import http_request
 
 
@@ -138,7 +138,7 @@ class TGI(BaseBackend):
         self,
         s: StreamExecutor,
         choices: List[str],
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
     ):
         decision = self.retry_for_expected(
             s.text_,
@@ -152,7 +152,7 @@
         s: StreamExecutor,
         max_tokens: int,
         stop: Union[str, List[str]],
-        sampling_params: SamplingParams,
+        sampling_params: SglSamplingParams,
         dtype: Optional[str] = None,
     ):
         if dtype is None:
@@ -6,10 +6,10 @@ from typing import List, Union
 from sglang.global_config import global_config
 from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
 from sglang.lang.ir import (
-    SamplingParams,
     SglArgument,
     SglConstantText,
     SglExpr,
+    SglSamplingParams,
     SglVariable,
 )
 
@@ -137,10 +137,9 @@ class CompiledFunction:
     ):
         backend = backend or global_config.default_backend
 
-        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
         kwargs.update(self.function.bind_arguments)
 
-        default_sampling_para = SamplingParams(
+        default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
             temperature=temperature,
@@ -173,7 +172,7 @@
 
         backend = backend or global_config.default_backend
 
-        default_sampling_para = SamplingParams(
+        default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
             temperature=temperature,
@@ -182,9 +181,6 @@
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
         )
-        batch_kwargs = [
-            {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
-        ]
 
         # Extract prefix by tracing and cache it
         if len(batch_kwargs) > 1:
@@ -12,7 +12,6 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import tqdm
 from sglang.global_config import global_config
 from sglang.lang.ir import (
-    SglArgument,
     SglCommitLazy,
     SglConcateAndAppend,
     SglConstantText,
@@ -89,7 +88,7 @@ def run_program_batch(
         for arguments in batch_arguments:
             rets.append(
                 run_program(
-                    program, backend, (), arguments, default_sampling_para, False, False
+                    program, backend, (), arguments, default_sampling_para, False, True
                 )
            )
     else:
@@ -108,7 +107,7 @@
                         arguments,
                         default_sampling_para,
                         False,
-                        False,
+                        True,
                     )
                 )
             if progress_bar:
@@ -292,7 +291,7 @@ class StreamExecutor:
 
         assert isinstance(other, SglExpr), f"{other}"
 
-        if isinstance(other, (SglConstantText, SglArgument)):
+        if isinstance(other, SglConstantText):
            self._execute_fill(other.value)
         elif isinstance(other, SglGen):
             self._execute_gen(other)
@@ -332,8 +331,6 @@
 
     def _execute_image(self, expr: SglImage):
         path = expr.path
-        if isinstance(path, SglArgument):
-            path = path.value
 
         base64_data = encode_image_base64(path)
 
@@ -419,7 +416,7 @@
                 "role": expr.role,
                 "content": [{"type": "text", "text": new_text}],
             }
-            for (image_path, image_base64_data) in self.cur_images:
+            for image_path, image_base64_data in self.cur_images:
                 last_msg["content"].append(
                     {
                         "type": "image_url",
@@ -480,6 +477,7 @@
             "top_k",
             "frequency_penalty",
             "presence_penalty",
+            "ignore_eos",
             "dtype",
             "regex",
         ]:
@@ -13,7 +13,7 @@ REGEX_STRING = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
 
 
 @dataclasses.dataclass
-class SamplingParams:
+class SglSamplingParams:
     max_new_tokens: int = 16
     stop: Union[str, List[str]] = ()
     temperature: float = 1.0
@@ -21,13 +21,14 @@ class SamplingParams:
     top_k: int = -1  # -1 means disable
     frequency_penalty: float = 0.0
     presence_penalty: float = 0.0
+    ignore_eos: bool = False
 
     # for constrained generation, not included in to_xxx_kwargs
     dtype: Optional[str] = None
     regex: Optional[str] = None
 
     def clone(self):
-        return SamplingParams(
+        return SglSamplingParams(
            self.max_new_tokens,
             self.stop,
             self.temperature,
@@ -67,6 +68,7 @@ class SamplingParams:
             "top_k": self.top_k,
             "frequency_penalty": self.frequency_penalty,
             "presence_penalty": self.presence_penalty,
+            "ignore_eos": self.ignore_eos,
             "regex": self.regex,
         }
 
@@ -98,13 +100,14 @@ class SglFunction:
         top_k: int = -1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
+        ignore_eos: bool = False,
         stream: bool = False,
         backend=None,
         **kwargs,
     ):
         from sglang.lang.interpreter import run_program
 
-        default_sampling_para = SamplingParams(
+        default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
             temperature=temperature,
@@ -112,9 +115,9 @@
             top_k=top_k,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
+            ignore_eos=ignore_eos,
         )
         backend = backend or global_config.default_backend
-        kwargs = {k: SglArgument(k, v) for k, v in kwargs.items()}
         return run_program(self, backend, args, kwargs, default_sampling_para, stream)
 
     def run_batch(
@@ -128,6 +131,7 @@
         top_k: int = -1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
+        ignore_eos: bool = False,
         backend=None,
         num_threads: Union[str, int] = "auto",
         progress_bar: bool = False,
@@ -139,7 +143,7 @@
             return []
         assert isinstance(batch_kwargs[0], dict)
 
-        default_sampling_para = SamplingParams(
+        default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
             temperature=temperature,
@@ -147,11 +151,9 @@
             top_k=top_k,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
+            ignore_eos=ignore_eos,
         )
         backend = backend or global_config.default_backend
-        batch_kwargs = [
-            {k: SglArgument(k, v) for k, v in kwargs.items()} for kwargs in batch_kwargs
-        ]
         return run_program_batch(
             self,
             backend,
@@ -321,12 +323,13 @@ class SglGen(SglExpr):
         top_k,
         frequency_penalty,
         presence_penalty,
+        ignore_eos,
         dtype,
         regex,
     ):
         super().__init__()
         self.name = name
-        self.sampling_params = SamplingParams(
+        self.sampling_params = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
             temperature=temperature,
@@ -334,6 +337,7 @@
             top_k=top_k,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
+            ignore_eos=ignore_eos,
             dtype=dtype,
             regex=regex,
         )
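The rename to `SglSamplingParams` disambiguates this frontend IR dataclass from the server-side `sglang.srt.sampling_params.SamplingParams`, which keeps its old name (see the final hunk of this diff). A small sketch of constructing the renamed class directly, using only fields visible in the hunks above:

```python
from sglang.lang.ir import SglSamplingParams

# Plain dataclass construction; ignore_eos is the field added in this release.
params = SglSamplingParams(max_new_tokens=64, temperature=0.7, ignore_eos=True)
copy = params.clone()  # clone() now also returns an SglSamplingParams
```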
@@ -40,7 +40,8 @@ def extract_prefix_by_tracing(program, backend):
     try:
         with TracingScope(tracer):
             tracer.ret_value = program.func(tracer, **arguments)
-    except StopTracing:
+    except (StopTracing, TypeError):
+        # Some exceptions may not be caught here
         pass
 
     # Run and cache prefix
@@ -0,0 +1,12 @@
+"""
+Backend configurations, may vary with different serving platforms.
+"""
+from dataclasses import dataclass
+
+
+@dataclass
+class BackendConfig:
+    extend_dependency_time: float = 0.03
+
+
+GLOBAL_BACKEND_CONFIG = BackendConfig()
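The new module holds one process-wide config object rather than per-request settings. A hedged sketch of how a deployment might tune it before starting the router; the 0.05 value is illustrative:

```python
from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG

# Wait slightly longer after a batch finishes before the next forward pass;
# the router reads this once at startup (see the manager.py hunk below).
GLOBAL_BACKEND_CONFIG.extend_dependency_time = 0.05
```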
@@ -1,10 +1,10 @@
 import asyncio
 import logging
-from typing import List, Tuple
 
 import uvloop
 import zmq
 import zmq.asyncio
+from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
 from sglang.srt.managers.router.model_rpc import ModelRpcClient
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import get_exception_traceback
@@ -28,6 +28,9 @@ class RouterManager:
         self.model_client = model_client
         self.recv_reqs = []
 
+        # Init some configs
+        self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
+
     async def loop_for_forward(self):
         while True:
             next_step_input = list(self.recv_reqs)
@@ -37,7 +40,12 @@
             for obj in out_pyobjs:
                 self.send_to_detokenizer.send_pyobj(obj)
 
-            # await for a while to accept input requests
+            # Sleep asynchronously so the subsequent request can arrive, avoiding a cache miss
+            if len(out_pyobjs) != 0:
+                has_finished = any([obj.finished for obj in out_pyobjs])
+                if has_finished:
+                    await asyncio.sleep(self.extend_dependency_time)
+
             await asyncio.sleep(0.001)
 
     async def loop_for_recv_requests(self):
@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
 from sglang.srt.model_config import ModelConfig
-from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     get_exception_traceback,
@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
                 if self.running_batch.is_empty():
                     self.running_batch = None
                     break
+        else:
+            # check the available size
+            available_size = (
+                self.token_to_kv_pool.available_size()
+                + self.tree_cache.evictable_size()
+            )
+            if available_size != self.max_total_num_token:
+                logger.warning(
+                    "Warning: "
+                    f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
+                    "KV cache pool leak detected!"
+                )
 
         if self.running_batch is not None and self.tp_rank == 0:
             if self.decode_forward_ct >= 20:
@@ -408,7 +419,9 @@
         token_ids = tuple(req.input_ids + req.output_ids)
         seq_len = len(token_ids) - 1
         indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
-        prefix_len = self.tree_cache.insert(token_ids, indices.clone())
+        prefix_len = self.tree_cache.insert(
+            token_ids[:seq_len], indices.clone()
+        )
 
         self.token_to_kv_pool.free(indices[:prefix_len])
         self.req_to_token_pool.free(req_pool_idx)
@@ -116,12 +116,12 @@ class RadixCache:
         for c_key, child in node.children.items():
             prefix_len = match(c_key, key)
             if prefix_len != 0:
-                if prefix_len == len(key) and prefix_len != len(c_key):
+                if prefix_len < len(c_key):
                     new_node = self._split_node(c_key, child, prefix_len)
                     value.append(new_node.value)
                     last_node[0] = new_node
                 else:
-                    value.append(child.value[:prefix_len])
+                    value.append(child.value)
                     last_node[0] = child
                     self._match_prefix_helper(child, key[prefix_len:], value, last_node)
                 break
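To see what the corrected condition changes: the old code split an edge only when the query ended inside it and otherwise appended the slice `child.value[:prefix_len]`; the new code splits whenever the match ends before the edge does, so only whole node values are ever appended. A self-contained sketch of the matching rule, with tuples standing in for the real token-index tensors:

```python
def match(a, b):
    # Length of the common prefix of two key sequences.
    i = 0
    while i < min(len(a), len(b)) and a[i] == b[i]:
        i += 1
    return i

c_key = ("a", "b", "c", "d")    # an existing edge in the radix tree
key = ("a", "b")                # the query
prefix_len = match(c_key, key)  # -> 2
# prefix_len < len(c_key), so the edge is split at position 2 and the new
# intermediate node's whole value is appended; no partial slice is needed.
```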
@@ -18,7 +18,7 @@ class Scheduler:
         self.tree_cache = tree_cache
 
     def new_token_estimation_ratio(self):
-        return 0.4 if self.schedule_heuristic != "fcfs" else 0.5
+        return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
 
     def get_priority_queue(self, forward_queue):
         if self.schedule_heuristic == "lpm":
@@ -351,7 +351,7 @@ class MixtralForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision, fall_back_to_pt=False
+            model_name_or_path, cache_dir, load_format, revision
         ):
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -12,7 +12,7 @@ class ServerArgs:
     load_format: str = "auto"
     tokenizer_mode: str = "auto"
     trust_remote_code: bool = True
-    mem_fraction_static: float = 0.91
+    mem_fraction_static: Optional[float] = None
     tp_size: int = 1
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
@@ -24,6 +24,11 @@
     def __post_init__(self):
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
+        if self.mem_fraction_static is None:
+            if self.tp_size > 1:
+                self.mem_fraction_static = 0.8
+            else:
+                self.mem_fraction_static = 0.9
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -88,7 +93,8 @@
         type=str,
         default=[],
         nargs="+",
-        help="Model mode: [flashinfer, no-cache, aggressive-new-fill]",
+        choices=["flashinfer", "no-cache"],
+        help="Model mode: [flashinfer, no-cache]",
     )
     parser.add_argument(
         "--schedule-heuristic",
@@ -174,7 +174,7 @@ def test_tool_use():
     def tool_use(s, lhs, rhs):
         s += "Please perform computations using a calculator. You can use calculate(expression) to get the results.\n"
         s += "For example,\ncalculate(1+2)=3\ncalculate(3*4)=12\n"
-        s += "Question: What is the product of " + lhs + " and " + rhs + "?\n"
+        s += "Question: What is the product of " + str(lhs) + " and " + str(rhs) + "?\n"
         s += (
             "Answer: The answer is calculate("
             + sgl.gen("expression", stop=")")
@@ -38,6 +38,26 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
     return pred
 
 
+def call_generate_outlines(
+    prompt, temperature, max_tokens, url, stop=[], regex=None, n=1
+):
+    data = {
+        "prompt": prompt,
+        "temperature": temperature,
+        "max_tokens": max_tokens,
+        "stop": stop,
+        "regex": regex,
+        "n": n,
+    }
+    res = requests.post(url, json=data)
+    assert res.status_code == 200
+    if n == 1:
+        pred = res.json()["text"][0][len(prompt) :]
+    else:
+        pred = [x[len(prompt) :] for x in res.json()["text"]]
+    return pred
+
+
 def call_generate_srt_raw(prompt, temperature, max_tokens, stop, url):
     data = {
         "text": prompt,
@@ -79,7 +99,7 @@ def call_select_vllm(context, choices, url):
         }
         res = requests.post(url, json=data)
         assert res.status_code == 200
-        scores.append(res.json()["prompt_score"])
+        scores.append(res.json().get("prompt_score", 0))
     return np.argmax(scores)
 
 """
@@ -92,7 +112,7 @@ def call_select_vllm(context, choices, url):
 
 
 def add_common_other_args_and_parse(parser):
-    parser.add_argument("--parallel", type=int, default=96)
+    parser.add_argument("--parallel", type=int, default=64)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=None)
     parser.add_argument(
@@ -67,7 +67,7 @@ def dump_state_text(filename, states, mode="w"):
         if isinstance(s, str):
             pass
         elif isinstance(s, ProgramState):
-            s = s.text().strip()
+            s = s.text()
         else:
             s = str(s)
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.1.1
+Version: 0.1.3
 Summary: A structured generation language for LLMs.
 License: Apache License
         Version 2.0, January 2004
@@ -329,25 +329,99 @@ You can find more examples at [examples/quick_start](examples/quick_start).
 
 ## Frontend: Structured Generation Language (SGLang)
 
+To begin with, import sglang.
+```python
+import sglang as sgl
+```
+
+`sglang` provides some simple primitives such as `gen`, `select`, and `fork`.
+You can implement your prompt flow in a function decorated by `sgl.function`.
+You can then invoke the function with `run` or `run_batch`.
+The system will manage the state, chat template, and parallelism for you.
+
 ### Control Flow
+```python
+@sgl.function
+def control_flow(s, question):
+    s += "To answer this question: " + question + ", "
+    s += "I need to use a " + sgl.gen("tool", choices=["calculator", "web browser"]) + ". "
+
+    # You can use if statements or nested function calls
+    if s["tool"] == "calculator":
+        s += "The math expression is" + sgl.gen("expression")
+    elif s["tool"] == "web browser":
+        s += "The website url is" + sgl.gen("url")
+```
 
 ### Parallelism
+```python
+@sgl.function
+def tip_suggestion(s):
+    s += (
+        "Here are two tips for staying healthy: "
+        "1. Balanced Diet. 2. Regular Exercise.\n\n"
+    )
+
+    forks = s.fork(2)  # Launch parallel prompts
+    for i, f in enumerate(forks):
+        f += f"Now, expand tip {i+1} into a paragraph:\n"
+        f += sgl.gen("detailed_tip", max_tokens=256, stop="\n\n")
+
+    s += "Tip 1:" + forks[0]["detailed_tip"] + "\n"
+    s += "Tip 2:" + forks[1]["detailed_tip"] + "\n"
+    s += "In summary" + sgl.gen("summary")
+```
 
 ### Multi Modality
 ```python
 @sgl.function
 def image_qa(s, image_file, question):
     s += sgl.user(sgl.image(image_file) + question)
-    s += sgl.assistant(sgl.gen("answer_1", max_tokens=256))
+    s += sgl.assistant(sgl.gen("answer", max_tokens=256))
 ```
 
-### Constrained decoding
+### Constrained Decoding
+```python
+@sgl.function
+def regular_expression_gen(s):
+    s += "Q: What is the IP address of the Google DNS servers?\n"
+    s += "A: " + sgl.gen(
+        "answer",
+        temperature=0,
+        regex=r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)",
+    )
+```
 
 ### Batching
+```python
+@sgl.function
+def text_qa(s, question):
+    s += "Q: " + question + "\n"
+    s += "A:" + sgl.gen("answer", stop="\n")
+
+states = text_qa.run_batch(
+    [
+        {"question": "What is the capital of the United Kingdom?"},
+        {"question": "What is the capital of France?"},
+        {"question": "What is the capital of Japan?"},
+    ],
+)
+```
 
 ### Streaming
+```python
+@sgl.function
+def text_qa(s, question):
+    s += "Q: " + question + "\n"
+    s += "A:" + sgl.gen("answer", stop="\n")
+
+state = text_qa.run(
+    question="What is the capital of France?",
+    temperature=0.1, stream=True)
 
-### Other Backends
+for out in state.text_iter():
+    print(out, end="", flush=True)
+```
 
 ## Backend: SGLang Runtime (SRT)
 The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
@@ -386,6 +460,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 
 ## Benchmark And Performance
 
+- Llama-7B on NVIDIA A10G, FP16, Tensor Parallelism=1
+![llama_7b](assets/llama_7b.jpg)
+
+- Mixtral-8x7B on NVIDIA A10G, FP16, Tensor Parallelism=8
+![mixtral_8x7b](assets/mixtral_8x7b.jpg)
+
+Learn more [here]().
+
 ## Roadmap
 - [ ] Function call
 - [ ] Quantization
@@ -25,6 +25,7 @@ sglang/lang/compiler.py
 sglang/lang/interpreter.py
 sglang/lang/ir.py
 sglang/lang/tracer.py
+sglang/srt/backend_config.py
 sglang/srt/hf_transformers_utils.py
 sglang/srt/memory_pool.py
 sglang/srt/model_config.py
@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6
 class SamplingParams:
     def __init__(
         self,
+        max_new_tokens: int = 16,
+        stop: Optional[Union[str, List[str]]] = None,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
-        stop: Optional[Union[str, List[str]]] = None,
-        max_new_tokens: int = 16,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
         dtype: Optional[str] = None,
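Note that this hunk reorders the server-side `SamplingParams` constructor so `max_new_tokens` and `stop` come first, mirroring the frontend `SglSamplingParams` field order; positional callers are affected. A small sketch (keyword arguments remain the safe style either way):

```python
from sglang.srt.sampling_params import SamplingParams

# After the reorder, the first two positional parameters are
# max_new_tokens and stop.
params = SamplingParams(32, "\n", temperature=0.7)
```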