sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/lang/interpreter.py
CHANGED
@@ -1,15 +1,18 @@
 """The interpreter that executes SGL programs"""

 import asyncio
+import contextvars
 import multiprocessing
 import queue
 import threading
 import uuid
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union

 import tqdm
+
 from sglang.global_config import global_config
 from sglang.lang.ir import (
     SglCommitLazy,
@@ -26,8 +29,13 @@ from sglang.lang.ir import (
     SglVariable,
     SglVarScopeBegin,
     SglVarScopeEnd,
+    SglVideo,
+)
+from sglang.utils import (
+    encode_image_base64,
+    encode_video_base64,
+    get_exception_traceback,
 )
-from sglang.utils import encode_image_base64


 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -58,7 +66,7 @@ def run_program(
         default_sampling_para,
         chat_template=None,
         stream=stream,
-        api_num_spec_tokens=program.api_num_spec_tokens,
+        num_api_spec_tokens=program.num_api_spec_tokens,
     )
     state = ProgramState(stream_executor)

@@ -84,9 +92,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint

-    #
-    if len(batch_arguments) > 1:
-        pin_program(program, backend)
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)

     # Run all programs
     if num_threads == "auto":
@@ -152,21 +160,12 @@ def run_program_batch(
     return rets


-def pin_program(program, backend):
-
-    # TODO: handle multiple backends
-    from sglang.lang.tracer import extract_prefix_by_tracing
-
-    prefix = extract_prefix_by_tracing(program, backend)
-    if prefix and len(prefix) > 64:
-        prefix_rid = backend.cache_prefix(prefix)
-        program.pin_prefix_rid = prefix_rid
-        return prefix_rid
-    return None
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing

-
-
-
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)


 class StreamExecutor:
@@ -179,7 +178,7 @@ class StreamExecutor:
         default_sampling_para,
         chat_template,
         stream,
-        api_num_spec_tokens=None,
+        num_api_spec_tokens=None,
         use_thread=True,
     ):
         self.sid = uuid.uuid4().hex
@@ -187,19 +186,16 @@ class StreamExecutor:
         self.arguments: Dict[str, Any] = arguments
         self.default_sampling_para = default_sampling_para
         self.stream = stream
-        self.api_num_spec_tokens = api_num_spec_tokens

         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
+        self.error_ = None

         # For completion
         self.text_ = ""  # The full text

-        # For speculative execution
-        self.speculated_text = ""
-
         # For chat
         self.messages_ = []  # The messages in the OpenAI API format
         self.chat_template = chat_template or self.backend.get_chat_template()
@@ -213,11 +209,21 @@ class StreamExecutor:
         # For fork/join
         self.fork_start_text_pos = None

+        # For speculative execution
+        self.num_api_spec_tokens = num_api_spec_tokens
+        self.speculated_text = ""
+
         # Worker thread
         self.use_thread = use_thread
         if self.use_thread:
             self.queue = queue.Queue()
-            self.worker = threading.Thread(target=self._thread_worker_func)
+
+            def _run_worker_in_context():
+                self._thread_worker_func()
+
+            self.worker = threading.Thread(
+                target=contextvars.copy_context().run, args=(_run_worker_in_context,)
+            )
             self.worker.start()

         # For streaming
@@ -248,17 +254,24 @@ class StreamExecutor:
     def set_var(self, name, value):
         self.variables[name] = value

-    def get_meta_info(self, name):
+    def get_meta_info(self, name, timeout=None):
         if name in self.variable_event:
-            self.variable_event[name].wait()
+            got = self.variable_event[name].wait(timeout)
+            if not got:
+                raise TimeoutError(f"Timeout while waiting for event '{name}'")
         ret = self.meta_info.get(name, None)
         return ret

-    def fork(
-        self
-
+    def fork(
+        self,
+        size: int = 1,
+        position_ids_offset: Optional[List[int]] = None,
+    ):
+        if size > 1:
+            self.submit(SglCommitLazy())

-
+        self.sync()
+        size = int(size)

         exes = [
             StreamExecutor(
@@ -268,14 +281,17 @@ class StreamExecutor:
                 self.chat_template,
                 self.stream,
             )
-            for _ in range(
+            for _ in range(size)
         ]
-        for i in range(
+        for i in range(size):
             exes[i].variables = dict(self.variables)
             exes[i].text_ = str(self.text_)
             exes[i].messages_ = list(self.messages_)
             exes[i].cur_role = self.cur_role
             exes[i].fork_start_text_pos = len(self.text_)
+            exes[i].images_ = list(self.images_)
+
+            # TODO(ying): handle API speculative execution

         return exes

@@ -287,6 +303,10 @@ class StreamExecutor:
         self.sync()
         return self.messages_

+    def error(self):
+        self.sync()
+        return self.error_
+
     def end(self):
         if self.use_thread:
             if self.worker.is_alive():
@@ -294,17 +314,39 @@ class StreamExecutor:
         self.backend.end_program(self)

     def _thread_worker_func(self):
+        error = None
+
         while True:
             expr = self.queue.get()
             if expr is None:
                 self.queue.task_done()
                 break

-            self._execute(expr)
+            try:
+                self._execute(expr)
+            except Exception as e:
+                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
+                error = e
+                break
             self.queue.task_done()
             if self.stream_text_event:
                 self.stream_text_event.set()

+        # Clean the queue and events
+        if error is not None:
+            try:
+                while True:
+                    self.queue.task_done()
+                    self.queue.get_nowait()
+            except queue.Empty:
+                pass
+            for name in self.variable_event:
+                self.variable_event[name].set()
+            if self.stream_var_event:
+                for name in self.stream_var_event:
+                    self.stream_var_event[name].set()
+            self.error_ = error
+
         if self.stream_text_event:
             self.stream_text_event.set()

@@ -331,6 +373,8 @@ class StreamExecutor:
             self._execute_role_end(other)
         elif isinstance(other, SglImage):
             self._execute_image(other)
+        elif isinstance(other, SglVideo):
+            self._execute_video(other)
         elif isinstance(other, SglVariable):
             self._execute_variable(other)
         elif isinstance(other, SglVarScopeBegin):
@@ -350,12 +394,23 @@ class StreamExecutor:
         else:
             raise ValueError(f"Unknown type: {type(other)}")

-    def _execute_fill(self, value: str):
+    def _execute_fill(self, value: str, prefix=False):
         value = str(value)
+
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+            and not prefix
+        ):
+            self.backend.spec_fill(value)
+            return
+
         if self.speculated_text.startswith(value):
             self.speculated_text = self.speculated_text[len(value) :]
         else:
             self.speculated_text = ""
+
         self.text_ += value

     def _execute_image(self, expr: SglImage):
@@ -367,68 +422,93 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token

+    def _execute_video(self, expr: SglVideo):
+        path = expr.path
+        num_frames = expr.num_frames
+
+        base64_data = encode_video_base64(path, num_frames)
+
+        self.images_.append((path, base64_data))
+        self.cur_images.append((path, base64_data))
+        self.text_ += self.chat_template.image_token
+
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)

+    def _spec_gen(self, sampling_params):
+        stop = sampling_params.stop
+        max_new_tokens = sampling_params.max_new_tokens
+        meta_info = {}
+
+        def regen():
+            nonlocal meta_info
+
+            sampling_params.max_new_tokens = max(
+                sampling_params.max_new_tokens, self.num_api_spec_tokens
+            )
+            sampling_params.stop = None
+            self.speculated_text, meta_info = self.backend.generate(
+                self, sampling_params=sampling_params
+            )
+
+        def find_stop():
+            if isinstance(stop, str):
+                return self.speculated_text.find(stop)
+            elif isinstance(stop, (tuple, list)):
+                pos = -1
+                for stop_str in stop:
+                    stop_pos = self.speculated_text.find(stop_str)
+                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                        pos = stop_pos
+                return pos
+            else:
+                raise Exception("Wrong type of stop in sampling parameters.")
+
+        if stop is None:
+            if len(self.speculated_text) < max_new_tokens:
+                regen()
+            comp = self.speculated_text[:max_new_tokens]
+            self.speculated_text = self.speculated_text[max_new_tokens:]
+        elif isinstance(stop, (str, list, tuple)):
+            if self.speculated_text == "":
+                regen()
+            stop_pos = find_stop()
+            if stop_pos == -1:
+                stop_pos = min(
+                    sampling_params.max_new_tokens,
+                    len(self.speculated_text),
+                )
+            comp = self.speculated_text[:stop_pos]
+            self.speculated_text = self.speculated_text[stop_pos:]
+        else:
+            raise ValueError("Wrong type of stop in sampling parameters.")
+
+        return comp, meta_info
+
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name

         if not self.stream:
-            if self.api_num_spec_tokens is not None:
-                stop = sampling_params.stop
-                max_new_tokens = sampling_params.max_new_tokens
-                meta_info = {}
-
-                def regen():
-                    sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.api_num_spec_tokens
-                    )
-                    sampling_params.stop = None
-                    self.speculated_text, meta_info = self.backend.generate(
-                        self, sampling_params=sampling_params
-                    )
-
-                def find_stop():
-                    if isinstance(stop, str):
-                        return self.speculated_text.find(stop), len(stop)
-                    elif isinstance(stop, (tuple, list)):
-                        pos = -1
-                        stop_len = 0
-                        for stop_str in stop:
-                            stop_pos = self.speculated_text.find(stop_str)
-                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
-                                pos = stop_pos
-                                stop_len = len(stop_str)
-                        return pos, stop_len
-                    else:
-                        raise Exception("Wrong type of stop in sampling parameters.")
-
-                if stop is None:
-                    if len(self.speculated_text) < max_new_tokens:
-                        regen()
-                    comp = self.speculated_text[:max_new_tokens]
-                    self.speculated_text = self.speculated_text[max_new_tokens:]
-                elif isinstance(stop, (str, list, tuple)):
-                    if self.speculated_text == "":
-                        regen()
-                    stop_pos, stop_len = find_stop()
-                    if stop_pos == -1:
-                        stop_pos, stop_len = (
-                            min(
-                                sampling_params.max_new_tokens,
-                                len(self.speculated_text),
-                            ),
-                            0,
-                        )
-                    comp = self.speculated_text[:stop_pos]
-                    self.speculated_text = self.speculated_text[stop_pos:]
-                else:
-                    raise ValueError("Wrong type of stop in sampling parameters.")
-            else:
+            if self.num_api_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
-                    self, sampling_params=sampling_params
+                    self,
+                    sampling_params=sampling_params,
                 )
+            else:
+                if self.backend.is_chat_model:
+                    # Speculative execution on models with only chat interface.
+                    # Store the calls into a temporary list.
+                    # They will be lazily executed later.
+                    comp, meta_info = self.backend.generate(
+                        self,
+                        sampling_params=sampling_params,
+                        spec_var_name=name,
+                    )
+                    return
+
+                else:  # Speculative execution on models with completion interface
+                    comp, meta_info = self._spec_gen(sampling_params)

             self.text_ += comp

@@ -436,13 +516,16 @@ class StreamExecutor:
             self.meta_info[name] = meta_info
             self.variable_event[name].set()
         else:
+            assert (
+                self.num_api_spec_tokens is None
+            ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
             )

+            self.variables[name] = ""
             self.stream_var_event[name].set()

-            self.variables[name] = ""
             for comp, meta_info in generator:
                 self.text_ += comp
                 self.variables[name] += comp
@@ -454,15 +537,19 @@ class StreamExecutor:
                 self.stream_var_event[name].set()

     def _execute_select(self, expr: SglSelect):
-
-
-
+        (
+            decision,
+            normalized_prompt_logprobs,
+            prefill_token_logprobs,
+            decode_token_logprobs,
+        ) = self.backend.select(self, expr.choices, expr.temperature)
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
             self.meta_info[name] = {
-                "
-                "
+                "normalized_prompt_logprobs": normalized_prompt_logprobs,
+                "prefill_token_logprobs": prefill_token_logprobs,
+                "decode_token_logprobs": decode_token_logprobs,
             }
             self.variable_event[name].set()
         self.text_ += decision
@@ -487,10 +574,19 @@ class StreamExecutor:

         prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

-        self._execute_fill(prefix)
+        self._execute_fill(prefix, prefix=True)
         self.cur_role_begin_pos = len(self.text_)

     def _execute_role_end(self, expr: SglRoleEnd):
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+        ):
+            # Execute the stored lazy generation calls
+            self.backend.role_end_generate(self)
+        self.cur_role = None
+
         new_text = self.text_[self.cur_role_begin_pos :].lstrip()

         _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -517,8 +613,6 @@ class StreamExecutor:
         # OpenAI chat API format
         self.messages_.append({"role": expr.role, "content": new_text})

-        self.cur_role = None
-
     def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
         self.variables[expr.name] = int(len(self.text_))

@@ -574,6 +668,10 @@ class StreamExecutor:
             "frequency_penalty",
             "presence_penalty",
             "ignore_eos",
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
             "dtype",
             "regex",
         ]:
@@ -634,8 +732,12 @@ class ProgramState:
        yield
        self.stream_executor.submit(SglVarScopeEnd(name))

-    def fork(
-
+    def fork(
+        self,
+        size: int = 1,
+        position_ids_offset: Optional[List[int]] = None,
+    ):
+        stream_executors = self.stream_executor.fork(size, position_ids_offset)
         states = [ProgramState(x) for x in stream_executors]
         state_group = ProgramStateGroup(states, self)
         return state_group
@@ -657,6 +759,9 @@ class ProgramState:
     def sync(self):
         return self.stream_executor.sync()

+    def error(self):
+        return self.stream_executor.error()
+
     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
             prev = 0
@@ -745,6 +850,9 @@ class ProgramState:
     def __setitem__(self, name, value):
         self.set_var(name, value)

+    def __contains__(self, name):
+        return name in self.stream_executor.variables
+
     def __del__(self):
         self.stream_executor.end()

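For orientation, the interpreter diff above adds error propagation from the worker thread plus ProgramState.error() and __contains__. The snippet below is a minimal illustrative sketch (not part of this diff) of how those additions might be used from the frontend API; the endpoint URL, model, and prompt are placeholders, and exact run() arguments may differ between 0.1.14 and 0.1.21.

# Illustrative sketch only: exercises the ProgramState.error() and __contains__
# additions shown in sglang/lang/interpreter.py above. URL and prompt are placeholders.
import sglang as sgl

@sgl.function
def qa(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=32)

# Assumes a runtime endpoint is already serving; the address is a placeholder.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

state = qa.run(question="What is the capital of France?")

# error() syncs the worker thread and returns the exception (or None) recorded by
# the new try/except in _thread_worker_func, instead of leaving callers blocked.
if state.error() is not None:
    raise state.error()

# __contains__ lets callers check whether a variable was produced before reading it.
if "answer" in state:
    print(state["answer"])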