sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/lang/interpreter.py CHANGED
@@ -6,6 +6,7 @@ import multiprocessing
 import queue
 import threading
 import uuid
+import warnings
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -28,8 +29,13 @@ from sglang.lang.ir import (
     SglVariable,
     SglVarScopeBegin,
     SglVarScopeEnd,
+    SglVideo,
+)
+from sglang.utils import (
+    encode_image_base64,
+    encode_video_base64,
+    get_exception_traceback,
 )
-from sglang.utils import encode_image_base64


 def run_internal(state, program, func_args, func_kwargs, sync):
@@ -60,7 +66,7 @@ def run_program(
         default_sampling_para,
         chat_template=None,
         stream=stream,
-        api_num_spec_tokens=program.api_num_spec_tokens,
+        num_api_spec_tokens=program.num_api_spec_tokens,
     )
     state = ProgramState(stream_executor)

@@ -86,9 +92,9 @@ def run_program_batch(
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint

-    # Extract prefix by tracing and cache it
-    if len(batch_arguments) > 1:
-        pin_program(program, backend)
+    # Pre-cache the common prefix for a batch. The prefix is extracted by tracing the program.
+    if global_config.enable_precache_with_tracing and len(batch_arguments) > 1:
+        cache_program(program, backend)

     # Run all programs
     if num_threads == "auto":
@@ -154,21 +160,12 @@ def run_program_batch(
     return rets


-def pin_program(program, backend):
-    if global_config.enable_prefix_sharing and program.pin_prefix_rid is None:
-        # TODO: handle multiple backends
-        from sglang.lang.tracer import extract_prefix_by_tracing
-
-        prefix = extract_prefix_by_tracing(program, backend)
-        if prefix and len(prefix) > 64:
-            prefix_rid = backend.cache_prefix(prefix)
-            program.pin_prefix_rid = prefix_rid
-            return prefix_rid
-    return None
+def cache_program(program, backend):
+    from sglang.lang.tracer import extract_prefix_by_tracing

-
-def unpin_program(program, backend):
-    pass
+    prefix = extract_prefix_by_tracing(program, backend)
+    if prefix and len(prefix) > 64:
+        backend.cache_prefix(prefix)


 class StreamExecutor:
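Note: together with the `cache()` method added in sglang/lang/ir.py below, this replaces the old `pin()`/`unpin()` prefix API. A minimal, hedged usage sketch; the `qa` function, its prompt text, and the endpoint URL are illustrative and not part of this diff:

    import sglang as sgl

    @sgl.function
    def qa(s, question):
        # The shared literal below is the prefix that tracing extracts.
        s += "Answer briefly.\n"
        s += "Q: " + question + "\nA: " + sgl.gen("answer", max_tokens=32)

    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    # run_batch() now calls cache_program() itself when
    # global_config.enable_precache_with_tracing is set and the batch has more
    # than one element; qa.cache() can also be called explicitly.
    states = qa.run_batch([{"question": "1 + 1?"}, {"question": "2 + 2?"}])
    print([st["answer"] for st in states])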
@@ -181,7 +178,7 @@ class StreamExecutor:
         default_sampling_para,
         chat_template,
         stream,
-        api_num_spec_tokens=None,
+        num_api_spec_tokens=None,
         use_thread=True,
     ):
         self.sid = uuid.uuid4().hex
@@ -189,19 +186,16 @@ class StreamExecutor:
         self.arguments: Dict[str, Any] = arguments
         self.default_sampling_para = default_sampling_para
         self.stream = stream
-        self.api_num_spec_tokens = api_num_spec_tokens

         self.variables = {}  # Dict[name: str -> value: str]
         self.variable_event = {}  # Dict[name: str -> event: threading.Event]
         self.meta_info = {}  # Dict[name: str -> info: str]
         self.is_finished = False
+        self.error_ = None

         # For completion
         self.text_ = ""  # The full text

-        # For speculative execution
-        self.speculated_text = ""
-
         # For chat
         self.messages_ = []  # The messages in the OpenAI API format
         self.chat_template = chat_template or self.backend.get_chat_template()
@@ -215,6 +209,10 @@ class StreamExecutor:
         # For fork/join
         self.fork_start_text_pos = None

+        # For speculative execution
+        self.num_api_spec_tokens = num_api_spec_tokens
+        self.speculated_text = ""
+
         # Worker thread
         self.use_thread = use_thread
         if self.use_thread:
@@ -293,6 +291,8 @@ class StreamExecutor:
             exes[i].fork_start_text_pos = len(self.text_)
             exes[i].images_ = list(self.images_)

+        # TODO(ying): handle API speculative execution
+
         return exes

     def text(self):
@@ -303,6 +303,10 @@ class StreamExecutor:
         self.sync()
         return self.messages_

+    def error(self):
+        self.sync()
+        return self.error_
+
     def end(self):
         if self.use_thread:
             if self.worker.is_alive():
@@ -310,17 +314,39 @@ class StreamExecutor:
         self.backend.end_program(self)

     def _thread_worker_func(self):
+        error = None
+
         while True:
             expr = self.queue.get()
             if expr is None:
                 self.queue.task_done()
                 break

-            self._execute(expr)
+            try:
+                self._execute(expr)
+            except Exception as e:
+                warnings.warn(f"Error in stream_executor: {get_exception_traceback()}")
+                error = e
+                break
             self.queue.task_done()
             if self.stream_text_event:
                 self.stream_text_event.set()

+        # Clean the queue and events
+        if error is not None:
+            try:
+                while True:
+                    self.queue.task_done()
+                    self.queue.get_nowait()
+            except queue.Empty:
+                pass
+            for name in self.variable_event:
+                self.variable_event[name].set()
+            if self.stream_var_event:
+                for name in self.stream_var_event:
+                    self.stream_var_event[name].set()
+            self.error_ = error
+
         if self.stream_text_event:
             self.stream_text_event.set()
@@ -347,6 +373,8 @@ class StreamExecutor:
             self._execute_role_end(other)
         elif isinstance(other, SglImage):
             self._execute_image(other)
+        elif isinstance(other, SglVideo):
+            self._execute_video(other)
         elif isinstance(other, SglVariable):
             self._execute_variable(other)
         elif isinstance(other, SglVarScopeBegin):
@@ -366,12 +394,23 @@ class StreamExecutor:
         else:
             raise ValueError(f"Unknown type: {type(other)}")

-    def _execute_fill(self, value: str):
+    def _execute_fill(self, value: str, prefix=False):
         value = str(value)
+
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+            and not prefix
+        ):
+            self.backend.spec_fill(value)
+            return
+
         if self.speculated_text.startswith(value):
             self.speculated_text = self.speculated_text[len(value) :]
         else:
             self.speculated_text = ""
+
         self.text_ += value

     def _execute_image(self, expr: SglImage):
@@ -383,68 +422,93 @@ class StreamExecutor:
         self.cur_images.append((path, base64_data))
         self.text_ += self.chat_template.image_token

+    def _execute_video(self, expr: SglVideo):
+        path = expr.path
+        num_frames = expr.num_frames
+
+        base64_data = encode_video_base64(path, num_frames)
+
+        self.images_.append((path, base64_data))
+        self.cur_images.append((path, base64_data))
+        self.text_ += self.chat_template.image_token
+
         # if global_config.eager_fill_image:
         #     self.backend.fill_image(self)

+    def _spec_gen(self, sampling_params):
+        stop = sampling_params.stop
+        max_new_tokens = sampling_params.max_new_tokens
+        meta_info = {}
+
+        def regen():
+            nonlocal meta_info
+
+            sampling_params.max_new_tokens = max(
+                sampling_params.max_new_tokens, self.num_api_spec_tokens
+            )
+            sampling_params.stop = None
+            self.speculated_text, meta_info = self.backend.generate(
+                self, sampling_params=sampling_params
+            )
+
+        def find_stop():
+            if isinstance(stop, str):
+                return self.speculated_text.find(stop)
+            elif isinstance(stop, (tuple, list)):
+                pos = -1
+                for stop_str in stop:
+                    stop_pos = self.speculated_text.find(stop_str)
+                    if stop_pos != -1 and (pos == -1 or stop_pos < pos):
+                        pos = stop_pos
+                return pos
+            else:
+                raise Exception("Wrong type of stop in sampling parameters.")
+
+        if stop is None:
+            if len(self.speculated_text) < max_new_tokens:
+                regen()
+            comp = self.speculated_text[:max_new_tokens]
+            self.speculated_text = self.speculated_text[max_new_tokens:]
+        elif isinstance(stop, (str, list, tuple)):
+            if self.speculated_text == "":
+                regen()
+            stop_pos = find_stop()
+            if stop_pos == -1:
+                stop_pos = min(
+                    sampling_params.max_new_tokens,
+                    len(self.speculated_text),
+                )
+            comp = self.speculated_text[:stop_pos]
+            self.speculated_text = self.speculated_text[stop_pos:]
+        else:
+            raise ValueError("Wrong type of stop in sampling parameters.")
+
+        return comp, meta_info
+
     def _execute_gen(self, expr: SglGen):
         sampling_params = self._resolve_sampling_params(expr.sampling_params)
         name = expr.name

         if not self.stream:
-            if self.api_num_spec_tokens is not None:
-                stop = sampling_params.stop
-                max_new_tokens = sampling_params.max_new_tokens
-                meta_info = {}
-
-                def regen():
-                    sampling_params.max_new_tokens = max(
-                        sampling_params.max_new_tokens, self.api_num_spec_tokens
-                    )
-                    sampling_params.stop = None
-                    self.speculated_text, meta_info = self.backend.generate(
-                        self, sampling_params=sampling_params
-                    )
-
-                def find_stop():
-                    if isinstance(stop, str):
-                        return self.speculated_text.find(stop), len(stop)
-                    elif isinstance(stop, (tuple, list)):
-                        pos = -1
-                        stop_len = 0
-                        for stop_str in stop:
-                            stop_pos = self.speculated_text.find(stop_str)
-                            if stop_pos != -1 and (pos == -1 or stop_pos < pos):
-                                pos = stop_pos
-                                stop_len = len(stop_str)
-                        return pos, stop_len
-                    else:
-                        raise Exception("Wrong type of stop in sampling parameters.")
-
-                if stop is None:
-                    if len(self.speculated_text) < max_new_tokens:
-                        regen()
-                    comp = self.speculated_text[:max_new_tokens]
-                    self.speculated_text = self.speculated_text[max_new_tokens:]
-                elif isinstance(stop, (str, list, tuple)):
-                    if self.speculated_text == "":
-                        regen()
-                    stop_pos, stop_len = find_stop()
-                    if stop_pos == -1:
-                        stop_pos, stop_len = (
-                            min(
-                                sampling_params.max_new_tokens,
-                                len(self.speculated_text),
-                            ),
-                            0,
-                        )
-                    comp = self.speculated_text[:stop_pos]
-                    self.speculated_text = self.speculated_text[stop_pos:]
-                else:
-                    raise ValueError("Wrong type of stop in sampling parameters.")
-            else:
+            if self.num_api_spec_tokens is None:
                 comp, meta_info = self.backend.generate(
-                    self, sampling_params=sampling_params
+                    self,
+                    sampling_params=sampling_params,
                 )
+            else:
+                if self.backend.is_chat_model:
+                    # Speculative execution on models with only chat interface.
+                    # Store the calls into a temporary list.
+                    # They will be lazily executed later.
+                    comp, meta_info = self.backend.generate(
+                        self,
+                        sampling_params=sampling_params,
+                        spec_var_name=name,
+                    )
+                    return
+
+                else:  # Speculative execution on models with completion interface
+                    comp, meta_info = self._spec_gen(sampling_params)

         self.text_ += comp

@@ -452,6 +516,9 @@ class StreamExecutor:
             self.meta_info[name] = meta_info
             self.variable_event[name].set()
         else:
+            assert (
+                self.num_api_spec_tokens is None
+            ), "stream is not supported with api speculative execution"
             generator = self.backend.generate_stream(
                 self, sampling_params=sampling_params
             )
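Note: the extracted `_spec_gen` path above applies when `num_api_spec_tokens` is set and the backend exposes a completion interface. A hedged sketch of opting in; the model name and token budget are placeholders, and the renamed decorator argument from sglang/api.py is assumed:

    import sglang as sgl

    @sgl.function(num_api_spec_tokens=256)
    def profile(s, name):
        s += name + " has the following profile.\n"
        s += "Job: " + sgl.gen("job", stop="\n") + "\n"
        s += "Hobby: " + sgl.gen("hobby", stop="\n") + "\n"

    sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo-instruct"))

    # The first gen() over-generates up to 256 tokens; later gen() and fill
    # calls are then served from self.speculated_text without extra API calls.
    state = profile.run(name="Alice")
    print(state["job"], state["hobby"])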
@@ -507,10 +574,19 @@ class StreamExecutor:

         prefix, _ = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)

-        self._execute_fill(prefix)
+        self._execute_fill(prefix, prefix=True)
         self.cur_role_begin_pos = len(self.text_)

     def _execute_role_end(self, expr: SglRoleEnd):
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+        ):
+            # Execute the stored lazy generation calls
+            self.backend.role_end_generate(self)
+        self.cur_role = None
+
         new_text = self.text_[self.cur_role_begin_pos :].lstrip()

         _, suffix = self.chat_template.get_prefix_and_suffix(expr.role, self.messages_)
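Note: for chat-only backends, the assistant block's gen() and fill calls are recorded (via spec_fill and the lazy generate path above) and flushed here in a single request. A hedged sketch; the model name and prompt are placeholders:

    import sglang as sgl

    sgl.set_default_backend(sgl.OpenAI("gpt-4"))

    @sgl.function(num_api_spec_tokens=512)
    def form(s, topic):
        s += sgl.user("Fill a two-line form about " + topic + ".")
        # Both gen() calls are executed together by role_end_generate()
        # when the assistant block closes.
        s += sgl.assistant(
            "Title: " + sgl.gen("title", stop="\n")
            + "\nSummary: " + sgl.gen("summary", stop="\n")
        )

    state = form.run(topic="speculative execution")
    print(state["title"], state["summary"])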
@@ -537,8 +613,6 @@ class StreamExecutor:
             # OpenAI chat API format
             self.messages_.append({"role": expr.role, "content": new_text})

-        self.cur_role = None
-
     def _execute_var_scope_begin(self, expr: SglVarScopeBegin):
         self.variables[expr.name] = int(len(self.text_))

@@ -681,6 +755,9 @@ class ProgramState:
     def sync(self):
         return self.stream_executor.sync()

+    def error(self):
+        return self.stream_executor.error()
+
     def text_iter(self, var_name: Optional[str] = None):
         if self.stream_executor.stream:
             prev = 0
@@ -769,6 +846,9 @@ class ProgramState:
     def __setitem__(self, name, value):
         self.set_var(name, value)

+    def __contains__(self, name):
+        return name in self.stream_executor.variables
+
     def __del__(self):
         self.stream_executor.end()

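Note: combined with the `error()` accessors added above, a caller can now inspect a finished state defensively. A hedged sketch reusing the illustrative `qa` function from earlier:

    state = qa.run(question="What is the capital of France?")
    state.sync()  # wait for the worker thread to drain the queue

    if state.error() is not None:      # surfaced from StreamExecutor.error_
        print("generation failed:", state.error())
    elif "answer" in state:            # new ProgramState.__contains__
        print(state["answer"])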
sglang/lang/ir.py CHANGED
@@ -81,6 +81,21 @@ class SglSamplingParams:
             "top_p": self.top_p,
             "top_k": self.top_k,
         }
+
+    def to_litellm_kwargs(self):
+        if self.regex is not None:
+            warnings.warn(
+                "Regular expression is not supported in the LiteLLM backend."
+            )
+        return {
+            "max_tokens": self.max_new_tokens,
+            "stop": self.stop or None,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+        }

     def to_srt_kwargs(self):
         return {
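Note: these kwargs feed the new LiteLLM backend (sglang/backend/litellm.py in the file list). A hedged usage sketch, assuming the backend is exported as sglang.LiteLLM; the model string is a placeholder:

    import sglang as sgl

    # Any provider/model string accepted by litellm should work here.
    sgl.set_default_backend(sgl.LiteLLM("gpt-3.5-turbo"))

    @sgl.function
    def summarize(s, doc):
        s += "Summarize in one sentence:\n" + doc + "\n"
        s += sgl.gen("summary", max_tokens=64, temperature=0.2)

    print(summarize.run(doc="sglang 0.1.17 adds a LiteLLM backend.")["summary"])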
@@ -97,9 +112,9 @@ class SglSamplingParams:


 class SglFunction:
-    def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
+    def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
         self.func = func
-        self.api_num_spec_tokens = api_num_spec_tokens
+        self.num_api_spec_tokens = num_api_spec_tokens
         self.bind_arguments = bind_arguments or {}
         self.pin_prefix_rid = None

@@ -193,17 +208,11 @@ class SglFunction:
         backend = backend or global_config.default_backend
         return trace_program(self, kwargs, backend)

-    def pin(self, backend=None):
-        from sglang.lang.interpreter import pin_program
-
-        backend = backend or global_config.default_backend
-        return pin_program(self, backend)
-
-    def unpin(self, backend=None):
-        from sglang.lang.interpreter import unpin_program
+    def cache(self, backend=None):
+        from sglang.lang.interpreter import cache_program

         backend = backend or global_config.default_backend
-        return unpin_program(self, backend)
+        return cache_program(self, backend)

     def compile(self, *, backend=None):
         from sglang.lang.compiler import compile_func
@@ -336,6 +345,15 @@ class SglImage(SglExpr):
         return f"SglImage({self.path})"


+class SglVideo(SglExpr):
+    def __init__(self, path, num_frames):
+        self.path = path
+        self.num_frames = num_frames
+
+    def __repr__(self) -> str:
+        return f"SglVideo({self.path}, {self.num_frames})"
+
+
 class SglGen(SglExpr):
     def __init__(
         self,
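Note: SglVideo is the IR node behind the new video input path; the interpreter change above encodes the sampled frames with encode_video_base64. A hedged sketch, assuming a matching video() helper in sglang/api.py and a running llava-vid endpoint; the path and URL are placeholders:

    import sglang as sgl

    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    @sgl.function
    def video_qa(s, path, question):
        # sgl.video(path, num_frames) emits an SglVideo expression.
        s += sgl.user(sgl.video(path, 16) + question)
        s += sgl.assistant(sgl.gen("answer", max_tokens=128))

    state = video_qa.run(path="./clip.mp4", question="What happens in this video?")
    print(state["answer"])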
sglang/lang/tracer.py CHANGED
@@ -110,7 +110,7 @@ class TracerProgramState(ProgramState):
     ##################################

     def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
-        assert (size >= 1)
+        assert size >= 1

         if self.only_trace_prefix:
             raise StopTracing()
sglang/launch_server.py CHANGED
@@ -2,11 +2,10 @@ import argparse

 from sglang.srt.server import ServerArgs, launch_server

-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     server_args = ServerArgs.from_cli_args(args)

-    launch_server(server_args, None)
+    launch_server(server_args, None)
sglang/launch_server_llavavid.py ADDED
@@ -0,0 +1,31 @@
+import argparse
+import multiprocessing as mp
+
+from sglang.srt.server import ServerArgs, launch_server
+
+if __name__ == "__main__":
+
+    model_overide_args = {}
+
+    model_overide_args["mm_spatial_pool_stride"] = 2
+    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_overide_args["num_frames"] = 16
+    model_overide_args["model_type"] = "llavavid"
+    if model_overide_args["num_frames"] == 32:
+        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_overide_args["max_sequence_length"] = 4096 * 2
+        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+        model_overide_args["model_max_length"] = 4096 * 2
+
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+
+    if "34b" in args.model_path.lower():
+        model_overide_args["image_token_index"] = 64002
+
+    server_args = ServerArgs.from_cli_args(args)
+
+    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
+    launch_server(server_args, pipe_writer, model_overide_args)
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -6,6 +6,9 @@ class FSMCache(BaseCache):
     def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
         super().__init__(enable=enable)

+        if tokenizer_path.endswith(".json"):
+            return
+
         from importlib.metadata import version

         if version("outlines") >= "0.0.35":
sglang/srt/flush_cache.py ADDED
@@ -0,0 +1,16 @@
+"""
+Usage:
+python3 -m sglang.srt.flush_cache --url http://localhost:30000
+"""
+
+import argparse
+
+import requests
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--url", type=str, default="http://localhost:30000")
+    args = parser.parse_args()
+
+    response = requests.get(args.url + "/flush_cache")
+    assert response.status_code == 200
sglang/srt/hf_transformers_utils.py CHANGED
@@ -3,7 +3,8 @@
 import json
 import os
 import warnings
-from typing import List, Optional, Tuple, Union
+import functools
+from typing import Optional, Union, AbstractSet, Collection, Literal

 from huggingface_hub import snapshot_download
 from transformers import (
@@ -30,10 +31,17 @@ def get_config_json(model_path: str):
     return config


-def get_config(model: str, trust_remote_code: bool, revision: Optional[str] = None):
+def get_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+    model_overide_args: Optional[dict] = None,
+):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
     )
+    if model_overide_args:
+        config.update(model_overide_args)
     return config

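Note: the new model_overide_args parameter lets the server patch a checkpoint's config before the model is built; launch_server_llavavid.py above is the first caller. A hedged direct-call sketch; the model path is a placeholder:

    from sglang.srt.hf_transformers_utils import get_config

    # Overrides are applied on top of the downloaded config via config.update().
    config = get_config(
        "liuhaotian/llava-v1.6-vicuna-7b",   # placeholder model path
        trust_remote_code=True,
        model_overide_args={"model_type": "llavavid", "num_frames": 16},
    )
    print(config.model_type)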
@@ -77,6 +85,9 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    if tokenizer_name.endswith(".json"):
+        return TiktokenTokenizer(tokenizer_name)
+
     """Gets a tokenizer for the given model name via Huggingface."""
     if is_multimodal_model(tokenizer_name):
         processor = get_processor(
@@ -163,3 +174,73 @@ def get_processor(
         **kwargs,
     )
     return processor
+
+
+class TiktokenTokenizer:
+    def __init__(self, tokenizer_path):
+        import tiktoken
+        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
+
+        # Read JSON
+        name = "tmp-json"
+        with open(tokenizer_path, "rb") as fin:
+            tok_dict = json.load(fin)
+
+        mergeable_ranks = {
+            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
+        }
+        special_tokens = {
+            bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
+        }
+        assert tok_dict["word_split"] == "V1"
+
+        kwargs = {
+            "name": name,
+            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
+            "mergeable_ranks": mergeable_ranks,
+            "special_tokens": special_tokens,
+        }
+        if "default_allowed_special" in tok_dict:
+            default_allowed_special = set(
+                [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
+            )
+        else:
+            default_allowed_special = None
+        if "vocab_size" in tok_dict:
+            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
+
+        tokenizer = tiktoken.Encoding(**kwargs)
+        tokenizer._default_allowed_special = default_allowed_special or set()
+
+        def encode_patched(
+            self,
+            text: str,
+            *,
+            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        ) -> list[int]:
+            if isinstance(allowed_special, set):
+                allowed_special |= self._default_allowed_special
+            return tiktoken.Encoding.encode(
+                self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
+            )
+        tokenizer.encode = functools.partial(encode_patched, tokenizer)
+
+        # Convert to HF interface
+        self.tokenizer = tokenizer
+        self.eos_token_id = tokenizer._special_tokens["<|eos|>"]
+        self.vocab_size = tokenizer.n_vocab
+
+    def encode(self, x, add_special_tokens=False):
+        return self.tokenizer.encode(x)
+
+    def decode(self, x):
+        return self.tokenizer.decode(x)
+
+    def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
+        if isinstance(batch[0], int):
+            batch = [[x] for x in batch]
+        return self.tokenizer.decode_batch(batch)
+
+    def convert_ids_to_tokens(self, index):
+        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
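Note: this class is only reached when get_tokenizer() receives a path ending in ".json" (see the hunk above), and the JSON must follow the regular_tokens/special_tokens schema read in __init__. A hedged sketch of the new code path; the file path is a placeholder:

    from sglang.srt.hf_transformers_utils import get_tokenizer

    tokenizer = get_tokenizer("./tokenizer.json")  # returns a TiktokenTokenizer

    ids = tokenizer.encode("Hello world")
    print(ids)
    print(tokenizer.decode(ids))
    print(tokenizer.convert_ids_to_tokens(ids[0]))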