sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/lang/interpreter.py
CHANGED
@@ -6,6 +6,7 @@ import multiprocessing
 import queue
 import threading
 import uuid
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -28,8 +29,13 @@ from sglang.lang.ir import (
     SglVariable,
     SglVarScopeBegin,
     SglVarScopeEnd,
+    SglVideo,
+)
+from sglang.utils import (
+    encode_image_base64,
+    encode_video_base64,
+    get_exception_traceback,
 )
-from sglang.utils import encode_image_base64
 
 
 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -60,7 +66,7 @@ def run_program(
         default_sampling_para,
         chat_template=None,
         stream=stream,
-        api_num_spec_tokens=program.api_num_spec_tokens,
+        num_api_spec_tokens=program.num_api_spec_tokens,
     )
     state = ProgramState(stream_executor)
 
@@ -86,9 +92,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
 
-    #
-    if len(batch_arguments) > 1:
-        pin_program(program, backend)
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)
 
     # Run all programs
     if num_threads == "auto":
@@ -154,21 +160,12 @@ def run_program_batch(
     return rets
 
 
-def pin_program(program, backend):
-
-    # TODO: handle multiple backends
-    from sglang.lang.tracer import extract_prefix_by_tracing
-
-    prefix = extract_prefix_by_tracing(program, backend)
-    if prefix and len(prefix) > 64:
-        prefix_rid = backend.cache_prefix(prefix)
-        program.pin_prefix_rid = prefix_rid
-        return prefix_rid
-    return None
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing
 
-
-
-
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)
 
 
 class StreamExecutor:
@@ -181,7 +178,7 @@ class StreamExecutor:
         default_sampling_para,
         chat_template,
         stream,
-        api_num_spec_tokens,
+        num_api_spec_tokens=None,
         use_thread=True,
     ):
         self.sid = uuid.uuid4().hex
@@ -189,19 +186,16 @@ class StreamExecutor:
         self.arguments: Dict[str, Any] = arguments
         self.default_sampling_para = default_sampling_para
         self.stream = stream
-        self.api_num_spec_tokens = api_num_spec_tokens
 
         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
+        self.error_ = None
 
         # For completion
         self.text_ = ""  # The full text
 
-        # For speculative execution
-        self.speculated_text = ""
-
         # For chat
         self.messages_ = []  # The messages in the OpenAI API format
         self.chat_template = chat_template or self.backend.get_chat_template()
@@ -215,6 +209,10 @@ class StreamExecutor:
         # For fork/join
         self.fork_start_text_pos = None
 
+        # For speculative execution
+        self.num_api_spec_tokens = num_api_spec_tokens
+        self.speculated_text = ""
+
         # Worker thread
         self.use_thread = use_thread
         if self.use_thread:
@@ -293,6 +291,8 @@ class StreamExecutor:
             exes[i].fork_start_text_pos = len(self.text_)
             exes[i].images_ = list(self.images_)
 
+        # TODO(ying): handle API speculative execution
+
         return exes
 
     def text(self):
@@ -303,6 +303,10 @@
         self.sync()
         return self.messages_
 
+    def error(self):
+        self.sync()
+        return self.error_
+
     def end(self):
         if self.use_thread:
             if self.worker.is_alive():
@@ -310,17 +314,39 @@
             self.backend.end_program(self)
 
     def _thread_worker_func(self):
+        error = None
+
         while True:
             expr = self.queue.get()
            if expr is None:
                 self.queue.task_done()
                 break
 
-            self._execute(expr)
+            try:
+                self._execute(expr)
+            except Exception as e:
+                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
+                error = e
+                break
             self.queue.task_done()
             if self.stream_text_event:
                 self.stream_text_event.set()
 
+        # Clean the queue and events
+        if error is not None:
+            try:
+                while True:
+                    self.queue.task_done()
+                    self.queue.get_nowait()
+            except queue.Empty:
+                pass
+            for name in self.variable_event:
+                self.variable_event[name].set()
+            if self.stream_var_event:
+                for name in self.stream_var_event:
+                    self.stream_var_event[name].set()
+            self.error_ = error
+
         if self.stream_text_event:
             self.stream_text_event.set()
 
@@ -347,6 +373,8 @@ class StreamExecutor:
             self._execute_role_end(other)
         elif isinstance(other, SglImage):
             self._execute_image(other)
+        elif isinstance(other, SglVideo):
+            self._execute_video(other)
         elif isinstance(other, SglVariable):
             self._execute_variable(other)
         elif isinstance(other, SglVarScopeBegin):
@@ -366,12 +394,23 @@
         else:
             raise ValueError(f"Unknown type: {type(other)}")
 
-    def _execute_fill(self, value: str):
+    def _execute_fill(self, value: str, prefix=False):
         value = str(value)
+
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+            and not prefix
+        ):
+            self.backend.spec_fill(value)
+            return
+
         if self.speculated_text.startswith(value):
             self.speculated_text = self.speculated_text[len(value) :]
         else:
             self.speculated_text = ""
+
         self.text_ += value
 
     def _execute_image(self, expr: SglImage):
@@ -383,68 +422,93 @@
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token
 
+    def _execute_video(self, expr: SglVideo):
+        path = expr.path
+        num_frames = expr.num_frames
+
+        base64_data = encode_video_base64(path, num_frames)
+
+        self.images_.append((path, base64_data))
+        self.cur_images.append((path, base64_data))
+        self.text_ += self.chat_template.image_token
+
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)
 
+    def _spec_gen(self, sampling_params):
+        stop = sampling_params.stop
+        max_new_tokens = sampling_params.max_new_tokens
+        meta_info = {}
+
+        def regen():
+            nonlocal meta_info
+
+            sampling_params.max_new_tokens = max(
+                sampling_params.max_new_tokens, self.num_api_spec_tokens
+            )
+            sampling_params.stop = None
+            self.speculated_text, meta_info = self.backend.generate(
+                self, sampling_params=sampling_params
+            )
+
+        def find_stop():
+            if isinstance(stop, str):
+                return self.speculated_text.find(stop)
+            elif isinstance(stop, (tuple, list)):
+                pos = -1
+                for stop_str in stop:
+                    stop_pos = self.speculated_text.find(stop_str)
+                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                        pos = stop_pos
+                return pos
+            else:
+                raise Exception("Wrong type of stop in sampling parameters.")
+
+        if stop is None:
+            if len(self.speculated_text) < max_new_tokens:
+                regen()
+            comp = self.speculated_text[:max_new_tokens]
+            self.speculated_text = self.speculated_text[max_new_tokens:]
+        elif isinstance(stop, (str, list, tuple)):
+            if self.speculated_text == "":
+                regen()
+            stop_pos = find_stop()
+            if stop_pos == -1:
+                stop_pos = min(
+                    sampling_params.max_new_tokens,
+                    len(self.speculated_text),
+                )
+            comp = self.speculated_text[:stop_pos]
+            self.speculated_text = self.speculated_text[stop_pos:]
+        else:
+            raise ValueError("Wrong type of stop in sampling parameters.")
+
+        return comp, meta_info
+
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name
 
         if not self.stream:
-            if self.api_num_spec_tokens is not None:
-                stop = sampling_params.stop
-                max_new_tokens = sampling_params.max_new_tokens
-                meta_info = {}
-
-                def regen():
-                    sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.api_num_spec_tokens
-                    )
-                    sampling_params.stop = None
-                    self.speculated_text, meta_info = self.backend.generate(
-                        self, sampling_params=sampling_params
-                    )
-
-                def find_stop():
-                    if isinstance(stop, str):
-                        return self.speculated_text.find(stop), len(stop)
-                    elif isinstance(stop, (tuple, list)):
-                        pos = -1
-                        stop_len = 0
-                        for stop_str in stop:
-                            stop_pos = self.speculated_text.find(stop_str)
-                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
-                                pos = stop_pos
-                                stop_len = len(stop_str)
-                        return pos, stop_len
-                    else:
-                        raise Exception("Wrong type of stop in sampling parameters.")
-
-                if stop is None:
-                    if len(self.speculated_text) < max_new_tokens:
-                        regen()
-                    comp = self.speculated_text[:max_new_tokens]
-                    self.speculated_text = self.speculated_text[max_new_tokens:]
-                elif isinstance(stop, (str, list, tuple)):
-                    if self.speculated_text == "":
-                        regen()
-                    stop_pos, stop_len = find_stop()
-                    if stop_pos == -1:
-                        stop_pos, stop_len = (
-                            min(
-                                sampling_params.max_new_tokens,
-                                len(self.speculated_text),
-                            ),
-                            0,
-                        )
-                    comp = self.speculated_text[:stop_pos]
-                    self.speculated_text = self.speculated_text[stop_pos:]
-                else:
-                    raise ValueError("Wrong type of stop in sampling parameters.")
-            else:
+            if self.num_api_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
-                    self, sampling_params=sampling_params
+                    self,
+                    sampling_params=sampling_params,
                 )
+            else:
+                if self.backend.is_chat_model:
+                    # Speculative execution on models with only chat interface.
+                    # Store the calls into a temporary list.
+                    # They will be lazily executed later.
+                    comp, meta_info = self.backend.generate(
+                        self,
+                        sampling_params=sampling_params,
+                        spec_var_name=name,
+                    )
+                    return
+
+                else:  # Speculative execution on models with completion interface
+                    comp, meta_info = self._spec_gen(sampling_params)
 
         self.text_ += comp
 
@@ -452,6 +516,9 @@
             self.meta_info[name] = meta_info
             self.variable_event[name].set()
         else:
+            assert (
+                self.num_api_spec_tokens is None
+            ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
             )
@@ -507,10 +574,19 @@
 
         prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
 
-        self._execute_fill(prefix)
+        self._execute_fill(prefix, prefix=True)
         self.cur_role_begin_pos = len(self.text_)
 
     def _execute_role_end(self, expr: SglRoleEnd):
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+        ):
+            # Execute the stored lazy generation calls
+            self.backend.role_end_generate(self)
+        self.cur_role = None
+
         new_text = self.text_[self.cur_role_begin_pos :].lstrip()
 
         _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
@@ -537,8 +613,6 @@
         # OpenAI chat API format
         self.messages_.append({"role": expr.role, "content": new_text})
 
-        self.cur_role = None
-
     def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
         self.variables[expr.name] = int(len(self.text_))
 
@@ -681,6 +755,9 @@ class ProgramState:
     def sync(self):
         return self.stream_executor.sync()
 
+    def error(self):
+        return self.stream_executor.error()
+
    def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
             prev = 0
@@ -769,6 +846,9 @@ class ProgramState:
     def __setitem__(self, name, value):
         self.set_var(name, value)
 
+    def __contains__(self, name):
+        return name in self.stream_executor.variables
+
     def __del__(self):
         self.stream_executor.end()
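The interpreter now records worker-thread failures instead of leaving the program hung: the exception is stored on the executor, pending queue items and variable events are released, and the new ProgramState.error() accessor (plus __contains__) exposes the result to callers. A minimal sketch of how calling code might use this, assuming a default backend has already been configured; the function and prompt are illustrative only, not part of this diff:

import sglang as sgl

@sgl.function
def qa(s, question):
    s += "Q: " + question + "\n"
    s += "A: " + sgl.gen("answer", max_tokens=64)

state = qa.run(question="What is the capital of France?")
state.sync()

if state.error() is not None:
    # The worker thread stored the exception and set all waiting events,
    # so sync() returns instead of blocking forever.
    print("generation failed:", state.error())
else:
    print(state["answer"])
    assert "answer" in state  # uses the new ProgramState.__contains__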
sglang/lang/ir.py
CHANGED
@@ -81,6 +81,21 @@ class SglSamplingParams:
             "top_p": self.top_p,
             "top_k": self.top_k,
         }
+
+    def to_litellm_kwargs(self):
+        if self.regex is not None:
+            warnings.warn(
+                "Regular expression is not supported in the LiteLLM backend."
+            )
+        return {
+            "max_tokens": self.max_new_tokens,
+            "stop": self.stop or None,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+        }
 
     def to_srt_kwargs(self):
         return {
@@ -97,9 +112,9 @@
 
 
 class SglFunction:
-    def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
+    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
         self.func = func
-        self.api_num_spec_tokens = api_num_spec_tokens
+        self.num_api_spec_tokens = num_api_spec_tokens
         self.bind_arguments = bind_arguments or {}
         self.pin_prefix_rid = None
 
@@ -193,17 +208,11 @@
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)
 
-    def pin(self, backend=None):
-        from sglang.lang.interpreter import pin_program
-
-        backend = backend or global_config.default_backend
-        return pin_program(self, backend)
-
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program
 
         backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
+        return cache_program(self, backend)
 
     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func
@@ -336,6 +345,15 @@ class SglImage(SglExpr):
         return f"SglImage({self.path})"
 
 
+class SglVideo(SglExpr):
+    def __init__(self, path, num_frames):
+        self.path = path
+        self.num_frames = num_frames
+
+    def __repr__(self) -> str:
+        return f"SglVideo({self.path}, {self.num_frames})"
+
+
 class SglGen(SglExpr):
     def __init__(
         self,
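SglVideo mirrors SglImage for multi-frame inputs; the interpreter changes above encode the selected frames with encode_video_base64 and reuse the chat template's image token. A hypothetical usage sketch, assuming the companion sgl.video(path, num_frames) helper that this release appears to add to the frontend API (sglang/api.py in the file list) and a video-capable backend such as the LLaVA-video server below; the clip path is a placeholder:

import sglang as sgl

@sgl.function
def describe_clip(s, video_path):
    s += sgl.user(
        # sgl.video(...) yields an SglVideo expression handled by _execute_video.
        sgl.video(video_path, num_frames=16)
        + "Describe what happens in this clip."
    )
    s += sgl.assistant(sgl.gen("description", max_tokens=128))

state = describe_clip.run(video_path="example_clip.mp4")
print(state["description"])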
sglang/lang/tracer.py
CHANGED
@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
     ##################################
 
     def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
-        assert
+        assert size >= 1
 
         if self.only_trace_prefix:
             raise StopTracing()
sglang/launch_server.py
CHANGED
@@ -2,11 +2,10 @@ import argparse
 
 from sglang.srt.server import ServerArgs, launch_server
 
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)
 
-    launch_server(server_args, None)
+    launch_server(server_args, None)
sglang/launch_server_llavavid.py
ADDED
@@ -0,0 +1,31 @@
+import argparse
+import multiprocessing as mp
+
+from sglang.srt.server import ServerArgs, launch_server
+
+if __name__ == "__main__":
+
+    model_overide_args = {}
+
+    model_overide_args["mm_spatial_pool_stride"] = 2
+    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_overide_args["num_frames"] = 16
+    model_overide_args["model_type"] = "llavavid"
+    if model_overide_args["num_frames"] == 32:
+        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_overide_args["max_sequence_length"] = 4096 * 2
+        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+        model_overide_args["model_max_length"] = 4096 * 2
+
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    if "34b" in args.model_path.lower():
+        model_overide_args["image_token_index"] = 64002
+
+    server_args = ServerArgs.from_cli_args(args)
+
+    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
+    launch_server(server_args, pipe_writer, model_overide_args)
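This launcher mirrors sglang.launch_server but pre-fills model_overide_args for the LLaVA-video architecture before starting the server, so it would typically be invoked as `python3 -m sglang.launch_server_llavavid --model-path <llava-video checkpoint> --port 30000` (the checkpoint is a placeholder; only --model-path and --port are standard ServerArgs flags, the rest of the configuration is hard-coded above).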
sglang/srt/constrained/fsm_cache.py
CHANGED
@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)
 
+        if tokenizer_path.endswith(".json"):
+            return
+
         from importlib.metadata import version
 
         if version("outlines") >= "0.0.35":
sglang/srt/flush_cache.py
ADDED
@@ -0,0 +1,16 @@
+"""
+Usage:
+python3 -m sglang.srt.flush_cache --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    args = parser.parse_args()
+
+    response = requests.get(args.url + "/flush_cache")
+    assert response.status_code == 200
sglang/srt/hf_transformers_utils.py
CHANGED
@@ -3,7 +3,8 @@
 import json
 import os
 import warnings
-from typing import Optional, Union
+import functools
+from typing import Optional, Union, AbstractSet, Collection, Literal
 
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -30,10 +31,17 @@ def get_config_json(model_path: str):
     return config
 
 
-def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    model_overide_args: Optional[dict] = None,
+):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if model_overide_args:
+        config.update(model_overide_args)
     return config
 
 
@@ -77,6 +85,9 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
@@ -163,3 +174,73 @@ def get_processor(
         **kwargs,
     )
     return processor
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import tiktoken
+        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+        # Read JSON
+        name = "tmp-json"
+        with open(tokenizer_path, "rb") as fin:
+            tok_dict = json.load(fin)
+
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
+        }
+        assert tok_dict["word_split"] == "V1"
+
+        kwargs = {
+            "name": name,
+            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in tok_dict:
+            default_allowed_special = set(
+                [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
+            )
+        else:
+            default_allowed_special = None
+        if "vocab_size" in tok_dict:
+            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
+
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> list[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
+            )
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.vocab_size = tokenizer.n_vocab
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode_batch(batch)
+
+    def convert_ids_to_tokens(self, index):
+        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
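The ".json" shortcut in get_tokenizer bypasses Hugging Face entirely and wraps a raw tiktoken vocabulary in TiktokenTokenizer, exposing just enough of the HF tokenizer surface (encode, decode, batch_decode, eos_token_id, vocab_size) for the rest of the runtime. A rough usage sketch; the file name is hypothetical, and the JSON must follow the schema read above (regular_tokens, special_tokens, word_split == "V1", and an <|eos|> special token):

from sglang.srt.hf_transformers_utils import get_tokenizer

# Any tokenizer path ending in ".json" takes the tiktoken branch.
tok = get_tokenizer("my_tiktoken_vocab.json")  # hypothetical file

ids = tok.encode("Hello world")   # list[int]; no special tokens are added
text = tok.decode(ids)            # full-string round trip
pieces = tok.batch_decode(ids)    # HF-style signature; decodes token by token
print(tok.vocab_size, tok.eos_token_id)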