sglang-0.1.18-py3-none-any.whl → sglang-0.1.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +26 -0
  3. sglang/backend/runtime_endpoint.py +18 -14
  4. sglang/bench_latency.py +34 -16
  5. sglang/global_config.py +1 -0
  6. sglang/lang/chat_template.py +41 -6
  7. sglang/lang/interpreter.py +5 -1
  8. sglang/lang/ir.py +61 -25
  9. sglang/srt/constrained/__init__.py +3 -2
  10. sglang/srt/hf_transformers_utils.py +7 -3
  11. sglang/srt/layers/extend_attention.py +2 -1
  12. sglang/srt/layers/fused_moe.py +181 -167
  13. sglang/srt/layers/logits_processor.py +55 -19
  14. sglang/srt/layers/radix_attention.py +24 -27
  15. sglang/srt/layers/token_attention.py +4 -1
  16. sglang/srt/managers/controller/infer_batch.py +2 -2
  17. sglang/srt/managers/controller/manager_single.py +1 -1
  18. sglang/srt/managers/controller/model_runner.py +27 -15
  19. sglang/srt/managers/controller/tp_worker.py +31 -14
  20. sglang/srt/managers/detokenizer_manager.py +4 -2
  21. sglang/srt/managers/io_struct.py +1 -1
  22. sglang/srt/managers/tokenizer_manager.py +14 -13
  23. sglang/srt/model_config.py +6 -0
  24. sglang/srt/models/gemma2.py +436 -0
  25. sglang/srt/models/llama2.py +3 -3
  26. sglang/srt/models/llama_classification.py +10 -7
  27. sglang/srt/models/minicpm.py +373 -0
  28. sglang/srt/models/qwen2_moe.py +454 -0
  29. sglang/srt/openai_api_adapter.py +2 -2
  30. sglang/srt/openai_protocol.py +1 -1
  31. sglang/srt/server.py +17 -8
  32. sglang/srt/server_args.py +14 -16
  33. sglang/srt/utils.py +68 -35
  34. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/METADATA +19 -13
  35. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/RECORD +38 -35
  36. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  37. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/WHEEL +0 -0
  38. {sglang-0.1.18.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.18"
+ __version__ = "0.1.19"

  # SGL API Components
  from sglang.api import (
sglang/api.py CHANGED
@@ -67,10 +67,16 @@ def gen(
  frequency_penalty: Optional[float] = None,
  presence_penalty: Optional[float] = None,
  ignore_eos: Optional[bool] = None,
+ return_logprob: Optional[bool] = None,
+ logprob_start_len: Optional[int] = None,
+ top_logprobs_num: Optional[int] = None,
+ return_text_in_logprobs: Optional[bool] = None,
  dtype: Optional[type] = None,
  choices: Optional[List[str]] = None,
  regex: Optional[str] = None,
  ):
+ """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
+
  if choices:
  return SglSelect(name, choices, 0.0 if temperature is None else temperature)

@@ -91,6 +97,10 @@ def gen(
  frequency_penalty,
  presence_penalty,
  ignore_eos,
+ return_logprob,
+ logprob_start_len,
+ top_logprobs_num,
+ return_text_in_logprobs,
  dtype,
  regex,
  )
@@ -106,6 +116,10 @@ def gen_int(
  frequency_penalty: Optional[float] = None,
  presence_penalty: Optional[float] = None,
  ignore_eos: Optional[bool] = None,
+ return_logprob: Optional[bool] = None,
+ logprob_start_len: Optional[int] = None,
+ top_logprobs_num: Optional[int] = None,
+ return_text_in_logprobs: Optional[bool] = None,
  ):
  return SglGen(
  name,
@@ -117,6 +131,10 @@ def gen_int(
  frequency_penalty,
  presence_penalty,
  ignore_eos,
+ return_logprob,
+ logprob_start_len,
+ top_logprobs_num,
+ return_text_in_logprobs,
  int,
  None,
  )
@@ -132,6 +150,10 @@ def gen_string(
  frequency_penalty: Optional[float] = None,
  presence_penalty: Optional[float] = None,
  ignore_eos: Optional[bool] = None,
+ return_logprob: Optional[bool] = None,
+ logprob_start_len: Optional[int] = None,
+ top_logprobs_num: Optional[int] = None,
+ return_text_in_logprobs: Optional[bool] = None,
  ):
  return SglGen(
  name,
@@ -143,6 +165,10 @@ def gen_string(
  frequency_penalty,
  presence_penalty,
  ignore_eos,
+ return_logprob,
+ logprob_start_len,
+ top_logprobs_num,
+ return_text_in_logprobs,
  str,
  None,
  )
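Note: the four new arguments added to gen/gen_int/gen_string above flow into SglSamplingParams (see sglang/lang/ir.py below) and, for the runtime backend, into the generate request. A minimal, hedged sketch of using them from a frontend program (the program, question, and endpoint URL are illustrative, not part of this release):

import sglang as sgl

@sgl.function
def qa(s, question):
    s += "Q: " + question + "\nA: "
    s += sgl.gen(
        "answer",
        max_new_tokens=64,
        return_logprob=True,            # ask for per-token logprobs
        logprob_start_len=0,            # report logprobs from the first position
        top_logprobs_num=5,             # also return the top-5 alternatives per position
        return_text_in_logprobs=True,   # include decoded text alongside token ids
    )

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = qa.run(question="What is the capital of France?")
print(state.get_meta_info("answer"))    # logprob fields are returned in the meta info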
sglang/backend/runtime_endpoint.py CHANGED
@@ -1,18 +1,18 @@
  import json
- from typing import Callable, List, Optional, Union
+ from typing import List, Optional

  import numpy as np
- import requests

  from sglang.backend.base_backend import BaseBackend
  from sglang.global_config import global_config
  from sglang.lang.chat_template import get_chat_template_by_model_path
  from sglang.lang.interpreter import StreamExecutor
- from sglang.lang.ir import SglArgument, SglSamplingParams
- from sglang.utils import encode_image_base64, find_printable_text, http_request
+ from sglang.lang.ir import SglSamplingParams
+ from sglang.utils import http_request


  class RuntimeEndpoint(BaseBackend):
+
  def __init__(
  self,
  base_url: str,
@@ -38,8 +38,7 @@ class RuntimeEndpoint(BaseBackend):
  self.model_info = res.json()

  self.chat_template = get_chat_template_by_model_path(
- self.model_info["model_path"]
- )
+ self.model_info["model_path"])

  def get_model_name(self):
  return self.model_info["model_path"]
@@ -125,6 +124,11 @@ class RuntimeEndpoint(BaseBackend):
  else:
  raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

+ for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+ value = getattr(sampling_params, item, None)
+ if value is not None:
+ data[item] = value
+
  self._add_images(s, data)

  res = http_request(
@@ -167,6 +171,11 @@ class RuntimeEndpoint(BaseBackend):
  else:
  raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

+ for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+ value = getattr(sampling_params, item, None)
+ if value is not None:
+ data[item] = value
+
  data["stream"] = True
  self._add_images(s, data)

@@ -181,21 +190,16 @@ class RuntimeEndpoint(BaseBackend):
  self._assert_success(res)
  pos = 0

- incomplete_text = ""
  for chunk in res.iter_lines(decode_unicode=False):
  chunk = chunk.decode("utf-8")
  if chunk and chunk.startswith("data:"):
  if chunk == "data: [DONE]":
  break
  data = json.loads(chunk[5:].strip("\n"))
- text = find_printable_text(data["text"][pos:])
+ chunk_text = data["text"][pos:]
  meta_info = data["meta_info"]
- pos += len(text)
- incomplete_text = data["text"][pos:]
- yield text, meta_info
-
- if len(incomplete_text) > 0:
- yield incomplete_text, meta_info
+ pos += len(chunk_text)
+ yield chunk_text, meta_info

  def select(
  self,
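Note: the last hunk above removes the find_printable_text buffering, so each streamed chunk is now yielded to the caller exactly as the server sent it. A hedged sketch of consuming such a stream, assuming the qa program and backend from the earlier example and the ProgramState.text_iter helper:

state = qa.run(question="Name three rivers.", stream=True)
for piece in state.text_iter("answer"):
    # pieces arrive as raw chunks; nothing is held back until it is "printable"
    print(piece, end="", flush=True)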
sglang/bench_latency.py CHANGED
@@ -108,7 +108,7 @@ def prepare_inputs(bench_args, tokenizer):
  for i in range(len(prompts)):
  assert len(input_ids[i]) > bench_args.cut_len

- tmp_input_ids = input_ids[i][:bench_args.cut_len]
+ tmp_input_ids = input_ids[i][: bench_args.cut_len]
  req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
  req.prefix_indices = []
  req.sampling_params = sampling_params
@@ -121,9 +121,9 @@ def prepare_inputs(bench_args, tokenizer):
  def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
  for i in range(len(reqs)):
  req = reqs[i]
- req.input_ids += input_ids[i][bench_args.cut_len:]
+ req.input_ids += input_ids[i][bench_args.cut_len :]
  req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
- i, :bench_args.cut_len
+ i, : bench_args.cut_len
  ]
  return reqs

@@ -151,7 +151,8 @@ def extend(reqs, model_runner):
  reqs=reqs,
  req_to_token_pool=model_runner.req_to_token_pool,
  token_to_kv_pool=model_runner.token_to_kv_pool,
- tree_cache=None)
+ tree_cache=None,
+ )
  batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
  output = model_runner.forward(batch, ForwardMode.EXTEND)
  next_token_ids, _ = batch.sample(output.next_token_logits)
@@ -165,6 +166,7 @@ def decode(input_token_ids, batch, model_runner):
  return next_token_ids, output.next_token_logits


+ @torch.inference_mode()
  def correctness_test(
  server_args,
  bench_args,
@@ -178,9 +180,10 @@ def correctness_test(
  # Prepare inputs
  input_ids, reqs = prepare_inputs(bench_args, tokenizer)

- # Prefill
- next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
- rank_print("prefill logits (first half)", next_token_logits)
+ if bench_args.cut_len > 0:
+ # Prefill
+ next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+ rank_print("prefill logits (first half)", next_token_logits)

  # Prepare extend inputs
  reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
@@ -190,7 +193,7 @@ def correctness_test(
  rank_print("prefill logits (final)", next_token_logits)

  # Decode
- output_ids = [list(req.input_ids) for req in reqs]
+ output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
  for _ in range(bench_args.output_len):
  next_token_ids, _ = decode(next_token_ids, batch, model_runner)
  for i in range(len(reqs)):
@@ -210,7 +213,9 @@ def latency_test(

  # Load the model
  model_runner, tokenizer = load_model(server_args, tp_rank)
- print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
+ print(
+ f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
+ )

  # Prepare inputs
  reqs = prepare_synthetic_inputs(bench_args, tokenizer)
@@ -230,7 +235,9 @@ def latency_test(
  prefill_latency = time.time() - tic
  tot_latency += prefill_latency
  throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
- rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+ rank_print(
+ f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+ )

  # Decode
  for i in range(output_len):
@@ -241,13 +248,24 @@ def latency_test(
  latency = time.time() - tic
  tot_latency += latency
  throughput = bench_args.batch_size / latency
- if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+ if i < 5:
+ rank_print(
+ f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+ )
  avg_decode_latency = (tot_latency - prefill_latency) / output_len
  avg_decode_throughput = bench_args.batch_size / avg_decode_latency
- rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
-
- throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
- rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+ rank_print(
+ f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+ )
+
+ throughput = (
+ (bench_args.input_len + bench_args.output_len)
+ * bench_args.batch_size
+ / tot_latency
+ )
+ rank_print(
+ f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+ )

  # Warm up
  run_once(4)
@@ -296,4 +314,4 @@ if __name__ == "__main__":
  format="%(message)s",
  )

- main(server_args, bench_args)
+ main(server_args, bench_args)
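Note: as a quick orientation for the figures this script prints, a small worked example of the derived quantities (all numbers illustrative, not measured output):

max_total_num_tokens = 200_000                    # KV cache capacity reported by the model runner
input_len, output_len, batch_size = 512, 128, 16

max_batch_size = max_total_num_tokens // (input_len + output_len)        # 312

prefill_latency = 0.25                            # seconds spent in the extend (prefill) step
prefill_throughput = input_len * batch_size / prefill_latency            # 32768 token/s

tot_latency = 2.25                                # prefill plus all decode steps
total_throughput = (input_len + output_len) * batch_size / tot_latency   # ~4551 token/s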
sglang/global_config.py CHANGED
@@ -39,4 +39,5 @@ class GlobalConfig:
  # This can improve the speed for large batch sizes during prefill.
  self.layer_sync_threshold = 8192

+
  global_config = GlobalConfig()
sglang/lang/chat_template.py CHANGED
@@ -84,7 +84,7 @@ register_chat_template(
  "system": ("SYSTEM:", "\n"),
  "user": ("USER:", "\n"),
  "assistant": ("ASSISTANT:", "\n"),
- },
+ }
  )
  )

@@ -116,6 +116,23 @@ register_chat_template(
  )
  )

+ # There is default system prompt for qwen
+ # reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
+ # The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+ register_chat_template(
+ ChatTemplate(
+ name="qwen",
+ default_system_prompt="You are a helpful assistant.",
+ role_prefix_and_suffix={
+ "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+ "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+ "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
+ },
+ style=ChatTemplateStyle.PLAIN,
+ stop_str=("<|im_end|>",),
+ )
+ )
+

  register_chat_template(
  ChatTemplate(
@@ -132,6 +149,7 @@ register_chat_template(
  )
  )

+ # Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
  register_chat_template(
  ChatTemplate(
  name="vicuna_v1.1",
@@ -148,6 +166,20 @@ register_chat_template(
  )
  )

+ # Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
+ register_chat_template(
+ ChatTemplate(
+ name="yi-1.5",
+ default_system_prompt=None,
+ role_prefix_and_suffix={
+ "system": ("", ""),
+ "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
+ "assistant": ("", "<|im_end|>\n"),
+ },
+ style=ChatTemplateStyle.PLAIN,
+ stop_str=("<|im_end|>",)
+ )
+ )

  register_chat_template(
  ChatTemplate(
@@ -187,7 +219,7 @@ register_chat_template(
  # Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
  register_chat_template(
  ChatTemplate(
- name="yi",
+ name="yi-vl",
  default_system_prompt=(
  "This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
  "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
@@ -289,8 +321,9 @@ def match_chat_ml(model_path: str):
  model_path = model_path.lower()
  if "tinyllama" in model_path:
  return get_chat_template("chatml")
- if "qwen" in model_path and "chat" in model_path:
- return get_chat_template("chatml")
+ # Now the suffix for qwen2 chat model is "instruct"
+ if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path):
+ return get_chat_template("qwen")
  if (
  "llava-v1.6-34b" in model_path
  or "llava-v1.6-yi-34b" in model_path
@@ -302,8 +335,10 @@ def match_chat_ml(model_path: str):
  @register_chat_template_matching_function
  def match_chat_yi(model_path: str):
  model_path = model_path.lower()
- if "yi" in model_path and "llava" not in model_path:
- return get_chat_template("yi")
+ if "yi-vl" in model_path and "llava" not in model_path:
+ return get_chat_template("yi-vl")
+ elif "yi-1.5" in model_path and "chat" in model_path:
+ return get_chat_template("yi-1.5")


  @register_chat_template_matching_function
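Note: the prompt produced by the new "qwen" template can be read directly off the registration above; a short sketch for a single user turn ("Hello" is an arbitrary example message):

from sglang.lang.chat_template import get_chat_template

qwen = get_chat_template("qwen")
sys_prefix, sys_suffix = qwen.role_prefix_and_suffix["system"]
usr_prefix, usr_suffix = qwen.role_prefix_and_suffix["user"]

# Yields:
#   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
#   <|im_start|>user\nHello<|im_end|>\n
#   <|im_start|>assistant\n
prompt = (
    sys_prefix + qwen.default_system_prompt + sys_suffix
    + usr_prefix + "Hello" + usr_suffix
    + qwen.role_prefix_and_suffix["assistant"][0]
)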
sglang/lang/interpreter.py CHANGED
@@ -523,9 +523,9 @@ class StreamExecutor:
  self, sampling_params=sampling_params
  )

+ self.variables[name] = ""
  self.stream_var_event[name].set()

- self.variables[name] = ""
  for comp, meta_info in generator:
  self.text_ += comp
  self.variables[name] += comp
@@ -668,6 +668,10 @@ class StreamExecutor:
  "frequency_penalty",
  "presence_penalty",
  "ignore_eos",
+ "return_logprob",
+ "logprob_start_len",
+ "top_logprobs_num",
+ "return_text_in_logprobs",
  "dtype",
  "regex",
  ]:
sglang/lang/ir.py CHANGED
@@ -23,6 +23,10 @@ class SglSamplingParams:
  frequency_penalty: float = 0.0
  presence_penalty: float = 0.0
  ignore_eos: bool = False
+ return_logprob: Optional[bool] = None
+ logprob_start_len: Optional[int] = None,
+ top_logprobs_num: Optional[int] = None,
+ return_text_in_logprobs: Optional[bool] = None,

  # for constrained generation, not included in to_xxx_kwargs
  dtype: Optional[str] = None
@@ -37,6 +41,11 @@ class SglSamplingParams:
  self.top_k,
  self.frequency_penalty,
  self.presence_penalty,
+ self.ignore_eos,
+ self.return_logprob,
+ self.logprob_start_len,
+ self.top_logprobs_num,
+ self.return_text_in_logprobs,
  )

  def to_openai_kwargs(self):
@@ -139,6 +148,10 @@ class SglFunction:
  frequency_penalty: float = 0.0,
  presence_penalty: float = 0.0,
  ignore_eos: bool = False,
+ return_logprob: Optional[bool] = None,
+ logprob_start_len: Optional[int] = None,
+ top_logprobs_num: Optional[int] = None,
+ return_text_in_logprobs: Optional[bool] = None,
  stream: bool = False,
  backend=None,
  **kwargs,
@@ -154,6 +167,10 @@ class SglFunction:
  frequency_penalty=frequency_penalty,
  presence_penalty=presence_penalty,
  ignore_eos=ignore_eos,
+ return_logprob=return_logprob,
+ logprob_start_len=logprob_start_len,
+ top_logprobs_num=top_logprobs_num,
+ return_text_in_logprobs=return_text_in_logprobs,
  )
  backend = backend or global_config.default_backend
  return run_program(self, backend, args, kwargs, default_sampling_para, stream)
@@ -170,6 +187,10 @@ class SglFunction:
  frequency_penalty: float = 0.0,
  presence_penalty: float = 0.0,
  ignore_eos: bool = False,
+ return_logprob: Optional[bool] = None,
+ logprob_start_len: Optional[int] = None,
+ top_logprobs_num: Optional[int] = None,
+ return_text_in_logprobs: Optional[bool] = None,
  backend=None,
  num_threads: Union[str, int] = "auto",
  progress_bar: bool = False,
@@ -185,8 +206,10 @@ class SglFunction:
  batch_kwargs = [
  {self.arg_names[i]: v for i, v in enumerate(arg_values)}
  for arg_values in batch_kwargs
- if isinstance(arg_values, (list, tuple)) and
- len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
+ if isinstance(arg_values, (list, tuple))
+ and len(self.arg_names) - len(self.arg_defaults)
+ <= len(arg_values)
+ <= len(self.arg_names)
  ]
  # Ensure to raise an exception if the number of arguments mismatch
  if len(batch_kwargs) != num_programs:
@@ -201,6 +224,10 @@ class SglFunction:
  frequency_penalty=frequency_penalty,
  presence_penalty=presence_penalty,
  ignore_eos=ignore_eos,
+ return_logprob=return_logprob,
+ logprob_start_len=logprob_start_len,
+ top_logprobs_num=top_logprobs_num,
+ return_text_in_logprobs=return_text_in_logprobs,
  )
  backend = backend or global_config.default_backend
  return run_program_batch(
@@ -348,7 +375,7 @@ class SglArgument(SglExpr):


  class SglImage(SglExpr):
- def __init__(self, path):
+ def __init__(self, path: str):
  self.path = path

  def __repr__(self) -> str:
@@ -356,7 +383,7 @@ class SglImage(SglExpr):


  class SglVideo(SglExpr):
- def __init__(self, path, num_frames):
+ def __init__(self, path: str, num_frames: int):
  self.path = path
  self.num_frames = num_frames

@@ -367,18 +394,23 @@ class SglVideo(SglExpr):
  class SglGen(SglExpr):
  def __init__(
  self,
- name,
- max_new_tokens,
- stop,
- temperature,
- top_p,
- top_k,
- frequency_penalty,
- presence_penalty,
- ignore_eos,
- dtype,
- regex,
+ name: Optional[str] = None,
+ max_new_tokens: Optional[int] = None,
+ stop: Optional[Union[str, List[str]]] = None,
+ temperature: Optional[float] = None,
+ top_p: Optional[float] = None,
+ top_k: Optional[int] = None,
+ frequency_penalty: Optional[float] = None,
+ presence_penalty: Optional[float] = None,
+ ignore_eos: Optional[bool] = None,
+ return_logprob: Optional[bool] = None,
+ logprob_start_len: Optional[int] = None,
+ top_logprobs_num: Optional[int] = None,
+ return_text_in_logprobs: Optional[bool] = None,
+ dtype: Optional[type] = None,
+ regex: Optional[str] = None,
  ):
+ """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
  super().__init__()
  self.name = name
  self.sampling_params = SglSamplingParams(
@@ -390,6 +422,10 @@ class SglGen(SglExpr):
  frequency_penalty=frequency_penalty,
  presence_penalty=presence_penalty,
  ignore_eos=ignore_eos,
+ return_logprob=return_logprob,
+ logprob_start_len=logprob_start_len,
+ top_logprobs_num=top_logprobs_num,
+ return_text_in_logprobs=return_text_in_logprobs,
  dtype=dtype,
  regex=regex,
  )
@@ -399,7 +435,7 @@ class SglGen(SglExpr):


  class SglConstantText(SglExpr):
- def __init__(self, value):
+ def __init__(self, value: str):
  super().__init__()
  self.value = value

@@ -408,7 +444,7 @@ class SglConstantText(SglExpr):


  class SglRoleBegin(SglExpr):
- def __init__(self, role):
+ def __init__(self, role: str):
  super().__init__()
  self.role = role

@@ -417,7 +453,7 @@ class SglRoleBegin(SglExpr):


  class SglRoleEnd(SglExpr):
- def __init__(self, role):
+ def __init__(self, role: str):
  super().__init__()
  self.role = role

@@ -426,7 +462,7 @@ class SglRoleEnd(SglExpr):


  class SglSelect(SglExpr):
- def __init__(self, name, choices, temperature):
+ def __init__(self, name: str, choices: List[str], temperature: float):
  super().__init__()
  self.name = name
  self.choices = choices
@@ -437,7 +473,7 @@ class SglSelect(SglExpr):


  class SglFork(SglExpr):
- def __init__(self, number, position_ids_offset=None):
+ def __init__(self, number: int, position_ids_offset=None):
  super().__init__()
  self.number = number
  self.position_ids_offset = position_ids_offset
@@ -450,7 +486,7 @@ class SglFork(SglExpr):


  class SglGetForkItem(SglExpr):
- def __init__(self, index):
+ def __init__(self, index: int):
  super().__init__()
  self.index = index

@@ -459,7 +495,7 @@ class SglGetForkItem(SglExpr):


  class SglVariable(SglExpr):
- def __init__(self, name, source):
+ def __init__(self, name: str, source):
  super().__init__()
  self.name = name
  self.source = source
@@ -469,7 +505,7 @@ class SglVariable(SglExpr):


  class SglVarScopeBegin(SglExpr):
- def __init__(self, name):
+ def __init__(self, name: str):
  super().__init__()
  self.name = name

@@ -478,7 +514,7 @@ class SglVarScopeBegin(SglExpr):


  class SglVarScopeEnd(SglExpr):
- def __init__(self, name):
+ def __init__(self, name: str):
  super().__init__()
  self.name = name

@@ -500,4 +536,4 @@ class SglCommitLazy(SglExpr):
  super().__init__()

  def __repr__(self):
- return f"CommitLazy()"
+ return "CommitLazy()"
sglang/srt/constrained/__init__.py CHANGED
@@ -5,13 +5,14 @@ from pydantic import BaseModel

  try:
  from outlines.caching import cache as disk_cache
- from outlines.fsm.guide import RegexGuide
  from outlines.caching import disable_cache
  from outlines.fsm.guide import RegexGuide
  from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
  from outlines.models.transformers import TransformerTokenizer
  except ImportError as e:
- print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
+ print(
+ f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
+ )
  raise

  try:
sglang/srt/hf_transformers_utils.py CHANGED
@@ -264,7 +264,9 @@ class TiktokenTokenizer:
  return self.tokenizer.decode_batch(batch)

  def apply_chat_template(self, messages, tokenize, add_generation_prompt):
- ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+ ret = self.chat_template.render(
+ messages=messages, add_generation_prompt=add_generation_prompt
+ )
  return self.encode(ret) if tokenize else ret


@@ -297,5 +299,7 @@ class SentencePieceTokenizer:
  return self.tokenizer.decode(batch)

  def apply_chat_template(self, messages, tokenize, add_generation_prompt):
- ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
- return self.encode(ret) if tokenize else ret
+ ret = self.chat_template.render(
+ messages=messages, add_generation_prompt=add_generation_prompt
+ )
+ return self.encode(ret) if tokenize else ret
sglang/srt/layers/extend_attention.py CHANGED
@@ -191,6 +191,7 @@ def extend_attention_fwd(
  b_seq_len_extend,
  max_len_in_batch,
  max_len_extend,
+ sm_scale=None,
  logit_cap=-1,
  ):
  """
@@ -213,7 +214,7 @@
  else:
  BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

- sm_scale = 1.0 / (Lq**0.5)
+ sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
  batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
  kv_group_num = q_extend.shape[1] // k_extend.shape[1]
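Note: the new sm_scale argument defaults to the usual 1/sqrt(d) attention scaling when left as None; a short sketch of the arithmetic (the head dimension is illustrative):

import math

Lq = 128                                              # query head dimension
sm_scale = None                                       # caller did not override
sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
assert abs(sm_scale - 1.0 / math.sqrt(128)) < 1e-12   # ~0.0884
# Passing sm_scale explicitly lets callers replace the 1/sqrt(head_dim) default,
# e.g. for models that scale attention logits differently.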