sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/bench_latency.py ADDED
@@ -0,0 +1,299 @@
+ """
+ Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+
+ # Usage (latency test):
+ python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+
+ # Usage (correctness test):
+ python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+
+ ### Reference output:
+ prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+         [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+         [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
+        device='cuda:0', dtype=torch.float16)
+ prefill logits (final) tensor([[-8.3203, -7.1211, 3.3379, ..., -4.9570, -4.1328, -3.4141],
+         [-8.9062, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0742],
+         [-9.6328, -9.0547, 4.0117, ..., -5.3047, -4.7148, -4.4609]],
+        device='cuda:0', dtype=torch.float16)
+ <s> The capital of France is.
+ The capital of the United States is Washington, D.C.
+
+ <s> The capital of the United Kindom is.
+ The capital of the United Kingdom is London.
+ The capital of the
+ <s> Today is a sunny day and I like go for a walk in the park.
+ I'm going to the park
+ """
+
+ import argparse
+ import dataclasses
+ import logging
+ import multiprocessing
+ import time
+
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+
+ from sglang.srt.hf_transformers_utils import get_tokenizer
+ from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+ from sglang.srt.managers.controller.model_runner import ModelRunner
+ from sglang.srt.model_config import ModelConfig
+ from sglang.srt.sampling_params import SamplingParams
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import suppress_other_loggers
+
+
+ @dataclasses.dataclass
+ class BenchArgs:
+     batch_size: int = 1
+     input_len: int = 1024
+     output_len: int = 4
+     correctness_test: bool = False
+     # This is only used for correctness test
+     cut_len: int = 4
+
+     @staticmethod
+     def add_cli_args(parser: argparse.ArgumentParser):
+         parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+         parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
+         parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+         parser.add_argument("--correctness-test", action="store_true")
+         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+
+     @classmethod
+     def from_cli_args(cls, args: argparse.Namespace):
+         attrs = [attr.name for attr in dataclasses.fields(cls)]
+         return cls(**{attr: getattr(args, attr) for attr in attrs})
+
+
+ def load_model(server_args, tp_rank):
+     suppress_other_loggers()
+
+     model_config = ModelConfig(path=server_args.model_path)
+     model_runner = ModelRunner(
+         model_config=model_config,
+         mem_fraction_static=server_args.mem_fraction_static,
+         gpu_id=tp_rank,
+         tp_rank=tp_rank,
+         tp_size=server_args.tp_size,
+         nccl_port=28888,
+         server_args=server_args,
+     )
+     print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+     tokenizer = get_tokenizer(
+         server_args.tokenizer_path,
+         tokenizer_mode=server_args.tokenizer_mode,
+         trust_remote_code=server_args.trust_remote_code,
+     )
+     if server_args.tp_size > 1:
+         dist.barrier()
+     return model_runner, tokenizer
+
+
+ def prepare_inputs(bench_args, tokenizer):
+     prompts = [
+         "The capital of France is",
+         "The capital of the United Kindom is",
+         "Today is a sunny day and I like",
+     ]
+     input_ids = [tokenizer.encode(p) for p in prompts]
+     sampling_params = SamplingParams(
+         temperature=0,
+         max_new_tokens=BenchArgs.output_len,
+     )
+
+     reqs = []
+     for i in range(len(prompts)):
+         assert len(input_ids[i]) > bench_args.cut_len
+
+         tmp_input_ids = input_ids[i][:bench_args.cut_len]
+         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+         req.prefix_indices = []
+         req.sampling_params = sampling_params
+         req.input_ids = req.origin_input_ids
+         reqs.append(req)
+
+     return input_ids, reqs
+
+
+ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+     for i in range(len(reqs)):
+         req = reqs[i]
+         req.input_ids += input_ids[i][bench_args.cut_len:]
+         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+             i, :bench_args.cut_len
+         ]
+     return reqs
+
+
+ def prepare_synthetic_inputs(bench_args, tokenizer):
+     input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+     sampling_params = SamplingParams(
+         temperature=0,
+         max_new_tokens=BenchArgs.output_len,
+     )
+
+     reqs = []
+     for i in range(len(input_ids)):
+         req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+         req.prefix_indices = []
+         req.sampling_params = sampling_params
+         req.input_ids = req.origin_input_ids
+         reqs.append(req)
+
+     return reqs
+
+
+ def extend(reqs, model_runner):
+     batch = Batch.init_new(
+         reqs=reqs,
+         req_to_token_pool=model_runner.req_to_token_pool,
+         token_to_kv_pool=model_runner.token_to_kv_pool,
+         tree_cache=None)
+     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
+     output = model_runner.forward(batch, ForwardMode.EXTEND)
+     next_token_ids, _ = batch.sample(output.next_token_logits)
+     return next_token_ids, output.next_token_logits, batch
+
+
+ def decode(input_token_ids, batch, model_runner):
+     batch.prepare_for_decode(input_token_ids.cpu().numpy())
+     output = model_runner.forward(batch, ForwardMode.DECODE)
+     next_token_ids, _ = batch.sample(output.next_token_logits)
+     return next_token_ids, output.next_token_logits
+
+
+ def correctness_test(
+     server_args,
+     bench_args,
+     tp_rank,
+ ):
+     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+     # Load the model
+     model_runner, tokenizer = load_model(server_args, tp_rank)
+
+     # Prepare inputs
+     input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+
+     # Prefill
+     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+     rank_print("prefill logits (first half)", next_token_logits)
+
+     # Prepare extend inputs
+     reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+
+     # Extend
+     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+     rank_print("prefill logits (final)", next_token_logits)
+
+     # Decode
+     output_ids = [list(req.input_ids) for req in reqs]
+     for _ in range(bench_args.output_len):
+         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+         for i in range(len(reqs)):
+             output_ids[i].append(next_token_ids[i])
+
+     # Print
+     for i in range(len(reqs)):
+         print(tokenizer.decode(output_ids[i]))
+
+
+ def latency_test(
+     server_args,
+     bench_args,
+     tp_rank,
+ ):
+     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+
+     # Load the model
+     model_runner, tokenizer = load_model(server_args, tp_rank)
+     print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
+
+     # Prepare inputs
+     reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+
+     def clear():
+         model_runner.req_to_token_pool.clear()
+         model_runner.token_to_kv_pool.clear()
+
+     @torch.inference_mode()
+     def run_once(output_len):
+         # Prefill
+         torch.cuda.synchronize()
+         tot_latency = 0
+         tic = time.time()
+         next_token_ids, _, batch = extend(reqs, model_runner)
+         torch.cuda.synchronize()
+         prefill_latency = time.time() - tic
+         tot_latency += prefill_latency
+         throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
+         rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+
+         # Decode
+         for i in range(output_len):
+             torch.cuda.synchronize()
+             tic = time.time()
+             next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+             torch.cuda.synchronize()
+             latency = time.time() - tic
+             tot_latency += latency
+             throughput = bench_args.batch_size / latency
+             if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+         avg_decode_latency = (tot_latency - prefill_latency) / output_len
+         avg_decode_throughput = bench_args.batch_size / avg_decode_latency
+         rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
+
+         throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
+         rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+
+     # Warm up
+     run_once(4)
+     clear()
+
+     # Run again
+     run_once(bench_args.output_len)
+
+
+ def main(server_args, bench_args):
+     print(bench_args)
+
+     if bench_args.correctness_test:
+         work_func = correctness_test
+     else:
+         work_func = latency_test
+
+     workers = []
+     for tp_rank in range(server_args.tp_size):
+         proc = multiprocessing.Process(
+             target=work_func,
+             args=(
+                 server_args,
+                 bench_args,
+                 tp_rank,
+             ),
+         )
+         proc.start()
+         workers.append(proc)
+
+     for proc in workers:
+         proc.join()
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     ServerArgs.add_cli_args(parser)
+     BenchArgs.add_cli_args(parser)
+     args = parser.parse_args()
+
+     server_args = ServerArgs.from_cli_args(args)
+     bench_args = BenchArgs.from_cli_args(args)
+
+     logging.basicConfig(
+         level=getattr(logging, server_args.log_level.upper()),
+         format="%(message)s",
+     )
+
+     main(server_args, bench_args)
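
Note on the new file: BenchArgs mirrors ServerArgs by registering one CLI flag per dataclass field and rebuilding the dataclass from the parsed namespace, so both argument groups can share a single parser (see the `__main__` block above). A minimal, self-contained sketch of that same pattern, using illustrative names that are not part of this diff:

import argparse
import dataclasses


@dataclasses.dataclass
class ExampleArgs:
    # Stand-in for BenchArgs: field defaults double as CLI defaults.
    batch_size: int = 1
    input_len: int = 1024

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument("--batch-size", type=int, default=ExampleArgs.batch_size)
        parser.add_argument("--input-len", type=int, default=ExampleArgs.input_len)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        # Keep only the attributes this dataclass declares.
        attrs = [f.name for f in dataclasses.fields(cls)]
        return cls(**{attr: getattr(args, attr) for attr in attrs})


parser = argparse.ArgumentParser()
ExampleArgs.add_cli_args(parser)
print(ExampleArgs.from_cli_args(parser.parse_args(["--batch-size", "8"])))
# ExampleArgs(batch_size=8, input_len=1024)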
sglang/global_config.py CHANGED
@@ -26,7 +26,17 @@ class GlobalConfig:
          self.concate_and_append_mode = "no_adjust"

          # Request dependency time due to network delay
-         self.request_dependency_time = 0.03
-
+         self.request_dependency_delay = 0.02
+         self.wait_for_new_request_delay = 0.0006
+
+         # New generation token ratio estimation
+         self.base_new_token_ratio = 0.4
+         self.base_min_new_token_ratio = 0.2
+         self.new_token_ratio_decay = 0.0001
+         self.new_token_ratio_recovery = 0.05
+
+         # The threshold (number of tokens) to trigger layer-wise cuda sync.
+         # This can improve the speed for large batch sizes during prefill.
+         self.layer_sync_threshold = 8192

  global_config = GlobalConfig()
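
Note: these values live on the module-level GlobalConfig singleton, so they can be inspected or tuned from Python before the runtime is started. A small sketch (the override value is an arbitrary example, not a recommendation):

from sglang.global_config import global_config

# New scheduling/estimation knobs introduced in this version.
print(global_config.request_dependency_delay)  # 0.02
print(global_config.base_new_token_ratio)      # 0.4
print(global_config.layer_sync_threshold)      # 8192

# Optionally override before launching a server or engine.
global_config.layer_sync_threshold = 4096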
sglang/lang/compiler.py CHANGED
@@ -4,7 +4,7 @@ from queue import Queue
  from typing import List, Union

  from sglang.global_config import global_config
- from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
+ from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
  from sglang.lang.ir import (
      SglArgument,
      SglConstantText,
@@ -184,7 +184,7 @@ class CompiledFunction:

          # Extract prefix by tracing and cache it
          if len(batch_kwargs) > 1:
-             pin_program(self.function, backend)
+             cache_program(self.function, backend)

          # Run all programs
          if num_threads == "auto":
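
Note: pin_program is renamed to cache_program; CompiledFunction now calls it when a batch has more than one set of arguments so the traced prefix is cached on the backend. A hedged sketch of invoking the renamed helper directly (the program and endpoint URL are placeholders, not part of this diff):

import sglang as sgl
from sglang.lang.interpreter import cache_program


@sgl.function
def hello(s, name):
    s += "Hello, " + name + ". "
    s += sgl.gen("greeting", max_tokens=16)


# Assumed setup: a locally running sglang server.
backend = sgl.RuntimeEndpoint("http://localhost:30000")
cache_program(hello, backend)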
sglang/lang/interpreter.py CHANGED
@@ -6,6 +6,7 @@ import multiprocessing
  import queue
  import threading
  import uuid
+ import warnings
  from concurrent.futures import ThreadPoolExecutor
  from contextlib import contextmanager
  from typing import Any, Callable, Dict, List, Optional, Union
@@ -30,7 +31,11 @@ from sglang.lang.ir import (
      SglVarScopeEnd,
      SglVideo,
  )
- from sglang.utils import encode_image_base64, encode_video_base64, get_exception_traceback
+ from sglang.utils import (
+     encode_image_base64,
+     encode_video_base64,
+     get_exception_traceback,
+ )


  def run_internal(state, program, func_args, func_kwargs, sync):
@@ -61,7 +66,7 @@ def run_program(
          default_sampling_para,
          chat_template=None,
          stream=stream,
-         api_num_spec_tokens=program.api_num_spec_tokens,
+         num_api_spec_tokens=program.num_api_spec_tokens,
      )
      state = ProgramState(stream_executor)

@@ -173,7 +178,7 @@ class StreamExecutor:
          default_sampling_para,
          chat_template,
          stream,
-         api_num_spec_tokens=None,
+         num_api_spec_tokens=None,
          use_thread=True,
      ):
          self.sid = uuid.uuid4().hex
@@ -181,20 +186,16 @@ class StreamExecutor:
          self.arguments: Dict[str, Any] = arguments
          self.default_sampling_para = default_sampling_para
          self.stream = stream
-         self.api_num_spec_tokens = api_num_spec_tokens

          self.variables = {}  # Dict[name: str -> value: str]
          self.variable_event = {}  # Dict[name: str -> event: threading.Event]
          self.meta_info = {}  # Dict[name: str -> info: str]
          self.is_finished = False
-         self.error = None
+         self.error_ = None

          # For completion
          self.text_ = ""  # The full text

-         # For speculative execution
-         self.speculated_text = ""
-
          # For chat
          self.messages_ = []  # The messages in the OpenAI API format
          self.chat_template = chat_template or self.backend.get_chat_template()
@@ -208,6 +209,10 @@ class StreamExecutor:
          # For fork/join
          self.fork_start_text_pos = None

+         # For speculative execution
+         self.num_api_spec_tokens = num_api_spec_tokens
+         self.speculated_text = ""
+
          # Worker thread
          self.use_thread = use_thread
          if self.use_thread:
@@ -286,6 +291,8 @@ class StreamExecutor:
              exes[i].fork_start_text_pos = len(self.text_)
              exes[i].images_ = list(self.images_)

+         # TODO(ying): handle API speculative execution
+
          return exes

      def text(self):
@@ -296,6 +303,10 @@ class StreamExecutor:
          self.sync()
          return self.messages_

+     def error(self):
+         self.sync()
+         return self.error_
+
      def end(self):
          if self.use_thread:
              if self.worker.is_alive():
@@ -314,7 +325,7 @@ class StreamExecutor:
              try:
                  self._execute(expr)
              except Exception as e:
-                 # print(f"Error in stream_executor: {get_exception_traceback()}")
+                 warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
                  error = e
                  break
              self.queue.task_done()
@@ -334,7 +345,7 @@ class StreamExecutor:
          if self.stream_var_event:
              for name in self.stream_var_event:
                  self.stream_var_event[name].set()
-         self.error = error
+         self.error_ = error

          if self.stream_text_event:
              self.stream_text_event.set()
@@ -383,12 +394,23 @@ class StreamExecutor:
          else:
              raise ValueError(f"Unknown type: {type(other)}")

-     def _execute_fill(self, value: str):
+     def _execute_fill(self, value: str, prefix=False):
          value = str(value)
+
+         if (
+             self.cur_role == "assistant"
+             and self.num_api_spec_tokens is not None
+             and self.backend.is_chat_model
+             and not prefix
+         ):
+             self.backend.spec_fill(value)
+             return
+
          if self.speculated_text.startswith(value):
              self.speculated_text = self.speculated_text[len(value) :]
          else:
              self.speculated_text = ""
+
          self.text_ += value

      def _execute_image(self, expr: SglImage):
@@ -413,65 +435,80 @@ class StreamExecutor:
          # if global_config.eager_fill_image:
          #     self.backend.fill_image(self)

+     def _spec_gen(self, sampling_params):
+         stop = sampling_params.stop
+         max_new_tokens = sampling_params.max_new_tokens
+         meta_info = {}
+
+         def regen():
+             nonlocal meta_info
+
+             sampling_params.max_new_tokens = max(
+                 sampling_params.max_new_tokens, self.num_api_spec_tokens
+             )
+             sampling_params.stop = None
+             self.speculated_text, meta_info = self.backend.generate(
+                 self, sampling_params=sampling_params
+             )
+
+         def find_stop():
+             if isinstance(stop, str):
+                 return self.speculated_text.find(stop)
+             elif isinstance(stop, (tuple, list)):
+                 pos = -1
+                 for stop_str in stop:
+                     stop_pos = self.speculated_text.find(stop_str)
+                     if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                         pos = stop_pos
+                 return pos
+             else:
+                 raise Exception("Wrong type of stop in sampling parameters.")
+
+         if stop is None:
+             if len(self.speculated_text) < max_new_tokens:
+                 regen()
+             comp = self.speculated_text[:max_new_tokens]
+             self.speculated_text = self.speculated_text[max_new_tokens:]
+         elif isinstance(stop, (str, list, tuple)):
+             if self.speculated_text == "":
+                 regen()
+             stop_pos = find_stop()
+             if stop_pos == -1:
+                 stop_pos = min(
+                     sampling_params.max_new_tokens,
+                     len(self.speculated_text),
+                 )
+             comp = self.speculated_text[:stop_pos]
+             self.speculated_text = self.speculated_text[stop_pos:]
+         else:
+             raise ValueError("Wrong type of stop in sampling parameters.")
+
+         return comp, meta_info
+
      def _execute_gen(self, expr: SglGen):
          sampling_params = self._resolve_sampling_params(expr.sampling_params)
          name = expr.name

          if not self.stream:
-             if self.api_num_spec_tokens is not None:
-                 stop = sampling_params.stop
-                 max_new_tokens = sampling_params.max_new_tokens
-                 meta_info = {}
-
-                 def regen():
-                     sampling_params.max_new_tokens = max(
-                         sampling_params.max_new_tokens, self.api_num_spec_tokens
-                     )
-                     sampling_params.stop = None
-                     self.speculated_text, meta_info = self.backend.generate(
-                         self, sampling_params=sampling_params
-                     )
-
-                 def find_stop():
-                     if isinstance(stop, str):
-                         return self.speculated_text.find(stop), len(stop)
-                     elif isinstance(stop, (tuple, list)):
-                         pos = -1
-                         stop_len = 0
-                         for stop_str in stop:
-                             stop_pos = self.speculated_text.find(stop_str)
-                             if stop_pos != -1 and (pos == -1 or stop_pos < pos):
-                                 pos = stop_pos
-                                 stop_len = len(stop_str)
-                         return pos, stop_len
-                     else:
-                         raise Exception("Wrong type of stop in sampling parameters.")
-
-                 if stop is None:
-                     if len(self.speculated_text) < max_new_tokens:
-                         regen()
-                     comp = self.speculated_text[:max_new_tokens]
-                     self.speculated_text = self.speculated_text[max_new_tokens:]
-                 elif isinstance(stop, (str, list, tuple)):
-                     if self.speculated_text == "":
-                         regen()
-                     stop_pos, stop_len = find_stop()
-                     if stop_pos == -1:
-                         stop_pos, stop_len = (
-                             min(
-                                 sampling_params.max_new_tokens,
-                                 len(self.speculated_text),
-                             ),
-                             0,
-                         )
-                     comp = self.speculated_text[:stop_pos]
-                     self.speculated_text = self.speculated_text[stop_pos:]
-                 else:
-                     raise ValueError("Wrong type of stop in sampling parameters.")
-             else:
+             if self.num_api_spec_tokens is None:
                  comp, meta_info = self.backend.generate(
-                     self, sampling_params=sampling_params
+                     self,
+                     sampling_params=sampling_params,
                  )
+             else:
+                 if self.backend.is_chat_model:
+                     # Speculative execution on models with only chat interface.
+                     # Store the calls into a temporary list.
+                     # They will be lazily executed later.
+                     comp, meta_info = self.backend.generate(
+                         self,
+                         sampling_params=sampling_params,
+                         spec_var_name=name,
+                     )
+                     return
+
+                 else:  # Speculative execution on models with completion interface
+                     comp, meta_info = self._spec_gen(sampling_params)

          self.text_ += comp

@@ -479,6 +516,9 @@ class StreamExecutor:
              self.meta_info[name] = meta_info
              self.variable_event[name].set()
          else:
+             assert (
+                 self.num_api_spec_tokens is None
+             ), "stream is not supported with api speculative execution"
              generator = self.backend.generate_stream(
                  self, sampling_params=sampling_params
              )
@@ -534,10 +574,19 @@ class StreamExecutor:

          prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

-         self._execute_fill(prefix)
+         self._execute_fill(prefix, prefix=True)
          self.cur_role_begin_pos = len(self.text_)

      def _execute_role_end(self, expr: SglRoleEnd):
+         if (
+             self.cur_role == "assistant"
+             and self.num_api_spec_tokens is not None
+             and self.backend.is_chat_model
+         ):
+             # Execute the stored lazy generation calls
+             self.backend.role_end_generate(self)
+         self.cur_role = None
+
          new_text = self.text_[self.cur_role_begin_pos :].lstrip()

          _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -564,8 +613,6 @@ class StreamExecutor:
          # OpenAI chat API format
          self.messages_.append({"role": expr.role, "content": new_text})

-         self.cur_role = None
-
      def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
          self.variables[expr.name] = int(len(self.text_))

@@ -709,7 +756,7 @@ class ProgramState:
          return self.stream_executor.sync()

      def error(self):
-         return self.stream_executor.error
+         return self.stream_executor.error()

      def text_iter(self, var_name: Optional[str] = None):
          if self.stream_executor.stream:
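
Note: taken together, the interpreter changes rename api_num_spec_tokens to num_api_spec_tokens, route assistant-role fills and generations through the backend when speculating against chat-only APIs (spec_fill / role_end_generate), and surface failures through StreamExecutor.error(), which now syncs the worker thread, with ProgramState.error() forwarding to it. A hedged end-to-end sketch of how user code might exercise this; the decorator keyword and model name are assumptions inferred from this diff, not a verified 0.1.18 API reference:

import sglang as sgl


# Assumption: @sgl.function accepts num_api_spec_tokens, matching the
# renamed field threaded through run_program above.
@sgl.function(num_api_spec_tokens=64)
def qa(s, question):
    s += sgl.system("You are a concise assistant.")
    s += sgl.user(question)
    s += sgl.assistant("Answer: " + sgl.gen("answer", max_tokens=32, stop="\n"))


sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
state = qa.run(question="What is the capital of France?")

# ProgramState.error() now syncs the executor before reporting.
if state.error() is not None:
    raise state.error()
print(state["answer"])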