sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +4 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/bench_latency.py +299 -0
  6. sglang/global_config.py +4 -1
  7. sglang/lang/compiler.py +2 -2
  8. sglang/lang/interpreter.py +1 -1
  9. sglang/lang/ir.py +15 -5
  10. sglang/launch_server.py +4 -1
  11. sglang/launch_server_llavavid.py +2 -1
  12. sglang/srt/constrained/__init__.py +13 -6
  13. sglang/srt/constrained/fsm_cache.py +6 -3
  14. sglang/srt/constrained/jump_forward.py +113 -25
  15. sglang/srt/conversation.py +2 -0
  16. sglang/srt/flush_cache.py +2 -0
  17. sglang/srt/hf_transformers_utils.py +64 -9
  18. sglang/srt/layers/fused_moe.py +186 -89
  19. sglang/srt/layers/logits_processor.py +53 -25
  20. sglang/srt/layers/radix_attention.py +34 -7
  21. sglang/srt/managers/controller/dp_worker.py +6 -3
  22. sglang/srt/managers/controller/infer_batch.py +142 -67
  23. sglang/srt/managers/controller/manager_multi.py +5 -5
  24. sglang/srt/managers/controller/manager_single.py +8 -3
  25. sglang/srt/managers/controller/model_runner.py +154 -54
  26. sglang/srt/managers/controller/radix_cache.py +4 -0
  27. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  28. sglang/srt/managers/controller/tp_worker.py +140 -135
  29. sglang/srt/managers/detokenizer_manager.py +15 -19
  30. sglang/srt/managers/io_struct.py +10 -4
  31. sglang/srt/managers/tokenizer_manager.py +14 -13
  32. sglang/srt/model_config.py +83 -4
  33. sglang/srt/models/chatglm.py +399 -0
  34. sglang/srt/models/commandr.py +2 -2
  35. sglang/srt/models/dbrx.py +1 -1
  36. sglang/srt/models/gemma.py +5 -1
  37. sglang/srt/models/grok.py +204 -137
  38. sglang/srt/models/llama2.py +11 -4
  39. sglang/srt/models/llama_classification.py +104 -0
  40. sglang/srt/models/llava.py +11 -8
  41. sglang/srt/models/llavavid.py +1 -1
  42. sglang/srt/models/mixtral.py +164 -115
  43. sglang/srt/models/mixtral_quant.py +0 -1
  44. sglang/srt/models/qwen.py +1 -1
  45. sglang/srt/models/qwen2.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/models/yivl.py +2 -2
  48. sglang/srt/openai_api_adapter.py +33 -23
  49. sglang/srt/openai_protocol.py +1 -1
  50. sglang/srt/server.py +60 -19
  51. sglang/srt/server_args.py +79 -44
  52. sglang/srt/utils.py +146 -37
  53. sglang/test/test_programs.py +28 -10
  54. sglang/utils.py +4 -3
  55. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
  56. sglang-0.1.18.dist-info/RECORD +78 -0
  57. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  58. sglang/srt/managers/router/infer_batch.py +0 -596
  59. sglang/srt/managers/router/manager.py +0 -82
  60. sglang/srt/managers/router/model_rpc.py +0 -818
  61. sglang/srt/managers/router/model_runner.py +0 -445
  62. sglang/srt/managers/router/radix_cache.py +0 -267
  63. sglang/srt/managers/router/scheduler.py +0 -59
  64. sglang-0.1.17.dist-info/RECORD +0 -81
  65. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  66. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.17"
+ __version__ = "0.1.18"

  # SGL API Components
  from sglang.api import (
@@ -24,10 +24,10 @@ from sglang.api import (

  # SGL Backends
  from sglang.backend.anthropic import Anthropic
+ from sglang.backend.litellm import LiteLLM
  from sglang.backend.openai import OpenAI
  from sglang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.backend.vertexai import VertexAI
- from sglang.backend.litellm import LiteLLM

  # Global Configurations
  from sglang.global_config import global_config
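With LiteLLM now exported at the top level, it can be plugged in like the other backends. A short usage sketch (the model name and prompt are illustrative, and litellm must be installed with the relevant provider credentials configured):

    import sglang as sgl

    # Route generation through litellm; any model identifier litellm accepts works here.
    sgl.set_default_backend(sgl.LiteLLM("gpt-3.5-turbo"))

    @sgl.function
    def hello(s):
        s += "Say hello in one word: "
        s += sgl.gen("greeting", max_tokens=8)

    print(hello.run()["greeting"])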
sglang/api.py CHANGED
@@ -1,4 +1,4 @@
- """Some Public API Definitions"""
+ """Public APIs of the language."""

  import os
  import re
@@ -43,14 +43,14 @@ def set_default_backend(backend: BaseBackend):
      global_config.default_backend = backend


- def flush_cache(backend: BaseBackend = None):
+ def flush_cache(backend: Optional[BaseBackend] = None):
      backend = backend or global_config.default_backend
      if backend is None:
          return False
      return backend.flush_cache()


- def get_server_args(backend: BaseBackend = None):
+ def get_server_args(backend: Optional[BaseBackend] = None):
      backend = backend or global_config.default_backend
      if backend is None:
          return None
@@ -158,7 +158,7 @@ def video(path: str, num_frames: int):

  def select(
      name: Optional[str] = None,
-     choices: List[str] = None,
+     choices: Optional[List[str]] = None,
      temperature: float = 0.0,
  ):
      assert choices is not None
sglang/backend/litellm.py CHANGED
@@ -13,7 +13,6 @@ except ImportError as e:


  class LiteLLM(BaseBackend):
-
      def __init__(
          self,
          model_name,
@@ -33,7 +32,8 @@ class LiteLLM(BaseBackend):
          self.model_name = model_name

          self.chat_template = chat_template or get_chat_template_by_model_path(
-             model_name)
+             model_name
+         )

          self.client_params = {
              "api_key": api_key,
sglang/backend/openai.py CHANGED
@@ -1,7 +1,7 @@
+ import dataclasses
  import logging
  import time
  import warnings
- import dataclasses
  from typing import Callable, List, Optional, Union

  import numpy as np
@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
      def get_chat_template(self):
          return self.chat_template

-     def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
-         num_api_spec_tokens: int, spec_var_name: str):
+     def _prepare_spec_execution(
+         self,
+         sampling_params: SglSamplingParams,
+         num_api_spec_tokens: int,
+         spec_var_name: str,
+     ):
          if "max_tokens" not in self.spec_kwargs:
              self.spec_kwargs["max_tokens"] = num_api_spec_tokens
          else:
-             assert (
-                 self.spec_kwargs["max_tokens"] == num_api_spec_tokens
-             )
+             assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens

          params = sampling_params.to_openai_kwargs()
          for key, value in params.items():
@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
                  )
                  prompt = s.messages_
              else:
-                 return self._prepare_spec_execution(sampling_params,
-                     s.num_api_spec_tokens, spec_var_name)
+                 return self._prepare_spec_execution(
+                     sampling_params, s.num_api_spec_tokens, spec_var_name
+                 )
          else:
              prompt = s.text_

@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
          ret_str = ret.choices[0].text
          ret_token = self.tokenizer.encode(ret_str)[0]
          self.token_usage.prompt_tokens += ret.usage.prompt_tokens
-         self.token_usage.completion_tokens= ret.usage.completion_tokens
+         self.token_usage.completion_tokens = ret.usage.completion_tokens

          # TODO:
          # 1. return logits as the scores
@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
          return decision, scores, None, None


- def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
+ def openai_completion(
+     client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+ ):
      for attempt in range(retries):
          try:
              if is_chat:
@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
              return comp


- def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
+ def openai_completion_stream(
+     client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+ ):
      for attempt in range(retries):
          try:
              if is_chat:
                  if "stop" in kwargs and kwargs["stop"] is None:
                      kwargs.pop("stop")
                  generator = client.chat.completions.create(
-                     messages=prompt, stream=True, stream_options={"include_usage": True},
-                     **kwargs
+                     messages=prompt,
+                     stream=True,
+                     stream_options={"include_usage": True},
+                     **kwargs,
                  )
                  for ret in generator:
                      if len(ret.choices) == 0:
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
                      yield content or "", {}
              else:
                  generator = client.completions.create(
-                     prompt=prompt, stream=True, stream_options={"include_usage": True},
-                     **kwargs
+                     prompt=prompt,
+                     stream=True,
+                     stream_options={"include_usage": True},
+                     **kwargs,
                  )
                  for ret in generator:
                      if len(ret.choices) == 0:
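The stream_options change above asks the OpenAI client for a final usage-only chunk, which is why both loops skip chunks whose choices list is empty. A minimal sketch of that client behavior outside sglang (the model name is illustrative; this assumes a recent openai client that supports stream_options):

    from openai import OpenAI

    client = OpenAI()
    stream = client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
        stream_options={"include_usage": True},
    )
    for chunk in stream:
        if len(chunk.choices) == 0:
            # Last chunk: no delta, only token usage.
            print(chunk.usage)
        else:
            print(chunk.choices[0].delta.content or "", end="")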
sglang/bench_latency.py ADDED
@@ -0,0 +1,299 @@
+ """
+ Benchmark the latency of a given model. It accepts arguments similar to those of launch_server.py.
+ 
+ # Usage (latency test):
+ python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy
+ 
+ # Usage (correctness test):
+ python -m sglang.bench_latency --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
+ 
+ ### Reference output:
+ prefill logits (first half) tensor([[-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+         [-10.0312, -9.5000, 0.8936, ..., -4.9414, -3.2402, -3.3633],
+         [ -9.1875, -10.2500, 2.7109, ..., -4.3359, -4.0664, -4.1328]],
+        device='cuda:0', dtype=torch.float16)
+ prefill logits (final) tensor([[-8.3203, -7.1211, 3.3379, ..., -4.9570, -4.1328, -3.4141],
+         [-8.9062, -9.0156, 4.1445, ..., -4.9922, -4.4961, -4.0742],
+         [-9.6328, -9.0547, 4.0117, ..., -5.3047, -4.7148, -4.4609]],
+        device='cuda:0', dtype=torch.float16)
+ <s> The capital of France is.
+ The capital of the United States is Washington, D.C.
+ 
+ <s> The capital of the United Kindom is.
+ The capital of the United Kingdom is London.
+ The capital of the
+ <s> Today is a sunny day and I like go for a walk in the park.
+ I'm going to the park
+ """
+ 
+ import argparse
+ import dataclasses
+ import logging
+ import multiprocessing
+ import time
+ 
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ 
+ from sglang.srt.hf_transformers_utils import get_tokenizer
+ from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
+ from sglang.srt.managers.controller.model_runner import ModelRunner
+ from sglang.srt.model_config import ModelConfig
+ from sglang.srt.sampling_params import SamplingParams
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import suppress_other_loggers
+ 
+ 
+ @dataclasses.dataclass
+ class BenchArgs:
+     batch_size: int = 1
+     input_len: int = 1024
+     output_len: int = 4
+     correctness_test: bool = False
+     # This is only used for correctness test
+     cut_len: int = 4
+ 
+     @staticmethod
+     def add_cli_args(parser: argparse.ArgumentParser):
+         parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
+         parser.add_argument("--input-len", type=int, default=BenchArgs.input_len)
+         parser.add_argument("--output-len", type=int, default=BenchArgs.output_len)
+         parser.add_argument("--correctness-test", action="store_true")
+         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+ 
+     @classmethod
+     def from_cli_args(cls, args: argparse.Namespace):
+         attrs = [attr.name for attr in dataclasses.fields(cls)]
+         return cls(**{attr: getattr(args, attr) for attr in attrs})
+ 
+ 
+ def load_model(server_args, tp_rank):
+     suppress_other_loggers()
+ 
+     model_config = ModelConfig(path=server_args.model_path)
+     model_runner = ModelRunner(
+         model_config=model_config,
+         mem_fraction_static=server_args.mem_fraction_static,
+         gpu_id=tp_rank,
+         tp_rank=tp_rank,
+         tp_size=server_args.tp_size,
+         nccl_port=28888,
+         server_args=server_args,
+     )
+     print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+     tokenizer = get_tokenizer(
+         server_args.tokenizer_path,
+         tokenizer_mode=server_args.tokenizer_mode,
+         trust_remote_code=server_args.trust_remote_code,
+     )
+     if server_args.tp_size > 1:
+         dist.barrier()
+     return model_runner, tokenizer
+ 
+ 
+ def prepare_inputs(bench_args, tokenizer):
+     prompts = [
+         "The capital of France is",
+         "The capital of the United Kindom is",
+         "Today is a sunny day and I like",
+     ]
+     input_ids = [tokenizer.encode(p) for p in prompts]
+     sampling_params = SamplingParams(
+         temperature=0,
+         max_new_tokens=BenchArgs.output_len,
+     )
+ 
+     reqs = []
+     for i in range(len(prompts)):
+         assert len(input_ids[i]) > bench_args.cut_len
+ 
+         tmp_input_ids = input_ids[i][:bench_args.cut_len]
+         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+         req.prefix_indices = []
+         req.sampling_params = sampling_params
+         req.input_ids = req.origin_input_ids
+         reqs.append(req)
+ 
+     return input_ids, reqs
+ 
+ 
+ def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
+     for i in range(len(reqs)):
+         req = reqs[i]
+         req.input_ids += input_ids[i][bench_args.cut_len:]
+         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
+             i, :bench_args.cut_len
+         ]
+     return reqs
+ 
+ 
+ def prepare_synthetic_inputs(bench_args, tokenizer):
+     input_ids = np.ones((bench_args.batch_size, bench_args.input_len), dtype=np.int32)
+     sampling_params = SamplingParams(
+         temperature=0,
+         max_new_tokens=BenchArgs.output_len,
+     )
+ 
+     reqs = []
+     for i in range(len(input_ids)):
+         req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+         req.prefix_indices = []
+         req.sampling_params = sampling_params
+         req.input_ids = req.origin_input_ids
+         reqs.append(req)
+ 
+     return reqs
+ 
+ 
+ def extend(reqs, model_runner):
+     batch = Batch.init_new(
+         reqs=reqs,
+         req_to_token_pool=model_runner.req_to_token_pool,
+         token_to_kv_pool=model_runner.token_to_kv_pool,
+         tree_cache=None)
+     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
+     output = model_runner.forward(batch, ForwardMode.EXTEND)
+     next_token_ids, _ = batch.sample(output.next_token_logits)
+     return next_token_ids, output.next_token_logits, batch
+ 
+ 
+ def decode(input_token_ids, batch, model_runner):
+     batch.prepare_for_decode(input_token_ids.cpu().numpy())
+     output = model_runner.forward(batch, ForwardMode.DECODE)
+     next_token_ids, _ = batch.sample(output.next_token_logits)
+     return next_token_ids, output.next_token_logits
+ 
+ 
+ def correctness_test(
+     server_args,
+     bench_args,
+     tp_rank,
+ ):
+     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+ 
+     # Load the model
+     model_runner, tokenizer = load_model(server_args, tp_rank)
+ 
+     # Prepare inputs
+     input_ids, reqs = prepare_inputs(bench_args, tokenizer)
+ 
+     # Prefill
+     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+     rank_print("prefill logits (first half)", next_token_logits)
+ 
+     # Prepare extend inputs
+     reqs = prepare_extend_inputs(bench_args, input_ids, reqs, model_runner)
+ 
+     # Extend
+     next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
+     rank_print("prefill logits (final)", next_token_logits)
+ 
+     # Decode
+     output_ids = [list(req.input_ids) for req in reqs]
+     for _ in range(bench_args.output_len):
+         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+         for i in range(len(reqs)):
+             output_ids[i].append(next_token_ids[i])
+ 
+     # Print
+     for i in range(len(reqs)):
+         print(tokenizer.decode(output_ids[i]))
+ 
+ 
+ def latency_test(
+     server_args,
+     bench_args,
+     tp_rank,
+ ):
+     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
+ 
+     # Load the model
+     model_runner, tokenizer = load_model(server_args, tp_rank)
+     print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
+ 
+     # Prepare inputs
+     reqs = prepare_synthetic_inputs(bench_args, tokenizer)
+ 
+     def clear():
+         model_runner.req_to_token_pool.clear()
+         model_runner.token_to_kv_pool.clear()
+ 
+     @torch.inference_mode()
+     def run_once(output_len):
+         # Prefill
+         torch.cuda.synchronize()
+         tot_latency = 0
+         tic = time.time()
+         next_token_ids, _, batch = extend(reqs, model_runner)
+         torch.cuda.synchronize()
+         prefill_latency = time.time() - tic
+         tot_latency += prefill_latency
+         throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
+         rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+ 
+         # Decode
+         for i in range(output_len):
+             torch.cuda.synchronize()
+             tic = time.time()
+             next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+             torch.cuda.synchronize()
+             latency = time.time() - tic
+             tot_latency += latency
+             throughput = bench_args.batch_size / latency
+             if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+         avg_decode_latency = (tot_latency - prefill_latency) / output_len
+         avg_decode_throughput = bench_args.batch_size / avg_decode_latency
+         rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
+ 
+         throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
+         rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+ 
+     # Warm up
+     run_once(4)
+     clear()
+ 
+     # Run again
+     run_once(bench_args.output_len)
+ 
+ 
+ def main(server_args, bench_args):
+     print(bench_args)
+ 
+     if bench_args.correctness_test:
+         work_func = correctness_test
+     else:
+         work_func = latency_test
+ 
+     workers = []
+     for tp_rank in range(server_args.tp_size):
+         proc = multiprocessing.Process(
+             target=work_func,
+             args=(
+                 server_args,
+                 bench_args,
+                 tp_rank,
+             ),
+         )
+         proc.start()
+         workers.append(proc)
+ 
+     for proc in workers:
+         proc.join()
+ 
+ 
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     ServerArgs.add_cli_args(parser)
+     BenchArgs.add_cli_args(parser)
+     args = parser.parse_args()
+ 
+     server_args = ServerArgs.from_cli_args(args)
+     bench_args = BenchArgs.from_cli_args(args)
+ 
+     logging.basicConfig(
+         level=getattr(logging, server_args.log_level.upper()),
+         format="%(message)s",
+     )
+ 
+     main(server_args, bench_args)
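The new script reuses the ServerArgs CLI plumbing and mirrors it in BenchArgs: one argparse parser collects the flags of several dataclasses, and each dataclass rebuilds itself from the parsed namespace via its own field names. A self-contained sketch of that pattern, detached from sglang (the DemoArgs name is made up for illustration):

    import argparse
    import dataclasses

    @dataclasses.dataclass
    class DemoArgs:
        batch_size: int = 1
        input_len: int = 1024

        @staticmethod
        def add_cli_args(parser: argparse.ArgumentParser):
            parser.add_argument("--batch-size", type=int, default=DemoArgs.batch_size)
            parser.add_argument("--input-len", type=int, default=DemoArgs.input_len)

        @classmethod
        def from_cli_args(cls, args: argparse.Namespace):
            # Pick out only the fields this dataclass owns, so several *Args
            # classes can share one parser without clashing.
            names = [field.name for field in dataclasses.fields(cls)]
            return cls(**{name: getattr(args, name) for name in names})

    parser = argparse.ArgumentParser()
    DemoArgs.add_cli_args(parser)
    print(DemoArgs.from_cli_args(parser.parse_args(["--batch-size", "8"])))
    # DemoArgs(batch_size=8, input_len=1024)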
sglang/global_config.py CHANGED
@@ -27,7 +27,7 @@ class GlobalConfig:

          # Request dependency time due to network delay
          self.request_dependency_delay = 0.02
-         self.wait_for_new_request_delay = 0.0004
+         self.wait_for_new_request_delay = 0.0006

          # New generation token ratio estimation
          self.base_new_token_ratio = 0.4
@@ -35,5 +35,8 @@ class GlobalConfig:
          self.new_token_ratio_decay = 0.0001
          self.new_token_ratio_recovery = 0.05

+         # The threshold (number of tokens) to trigger layer-wise cuda sync.
+         # This can improve the speed for large batch sizes during prefill.
+         self.layer_sync_threshold = 8192

  global_config = GlobalConfig()
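The new layer_sync_threshold knob is only a config value here; a hypothetical sketch of how a model runner could consult it during prefill (the helper below is illustrative, not part of this diff):

    import torch

    from sglang.global_config import global_config

    def maybe_layer_sync(num_tokens_in_batch: int) -> None:
        # Hypothetical helper, called once per transformer layer during prefill.
        # Syncing per layer on very large batches keeps the CPU from queueing too
        # far ahead of the GPU, which the config comment above says helps speed.
        if num_tokens_in_batch > global_config.layer_sync_threshold:
            torch.cuda.synchronize()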
sglang/lang/compiler.py CHANGED
@@ -4,7 +4,7 @@ from queue import Queue
  from typing import List, Union

  from sglang.global_config import global_config
- from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
+ from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
  from sglang.lang.ir import (
      SglArgument,
      SglConstantText,
@@ -184,7 +184,7 @@ class CompiledFunction:

          # Extract prefix by tracing and cache it
          if len(batch_kwargs) > 1:
-             pin_program(self.function, backend)
+             cache_program(self.function, backend)

          # Run all programs
          if num_threads == "auto":
sglang/lang/interpreter.py CHANGED
@@ -507,7 +507,7 @@ class StreamExecutor:
                  )
                  return

-         else: # Speculative execution on models with completion interface
+         else:  # Speculative execution on models with completion interface
              comp, meta_info = self._spec_gen(sampling_params)

              self.text_ += comp
sglang/lang/ir.py CHANGED
@@ -81,12 +81,10 @@ class SglSamplingParams:
              "top_p": self.top_p,
              "top_k": self.top_k,
          }
-
+
      def to_litellm_kwargs(self):
          if self.regex is not None:
-             warnings.warn(
-                 "Regular expression is not supported in the LiteLLM backend."
-             )
+             warnings.warn("Regular expression is not supported in the LiteLLM backend.")
          return {
              "max_tokens": self.max_new_tokens,
              "stop": self.stop or None,
@@ -122,6 +120,7 @@ class SglFunction:
          argspec = inspect.getfullargspec(func)
          assert argspec.args[0] == "s", 'The first argument must be "s"'
          self.arg_names = argspec.args[1:]
+         self.arg_defaults = argspec.defaults if argspec.defaults is not None else []

      def bind(self, **kwargs):
          assert all(key in self.arg_names for key in kwargs)
@@ -180,7 +179,18 @@ class SglFunction:
          assert isinstance(batch_kwargs, (list, tuple))
          if len(batch_kwargs) == 0:
              return []
-         assert isinstance(batch_kwargs[0], dict)
+         if not isinstance(batch_kwargs[0], dict):
+             num_programs = len(batch_kwargs)
+             # change the list of argument values to dict of arg_name -> arg_value
+             batch_kwargs = [
+                 {self.arg_names[i]: v for i, v in enumerate(arg_values)}
+                 for arg_values in batch_kwargs
+                 if isinstance(arg_values, (list, tuple)) and
+                 len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
+             ]
+             # Ensure to raise an exception if the number of arguments mismatch
+             if len(batch_kwargs) != num_programs:
+                 raise Exception("Given arguments mismatch the SGL function signature")

          default_sampling_para = SglSamplingParams(
              max_new_tokens=max_new_tokens,
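The run_batch change above means callers can now pass positional argument lists as well as keyword dicts. A usage sketch (the function body and the local endpoint are illustrative):

    import sglang as sgl

    @sgl.function
    def qa(s, question, hint=""):
        s += "Q: " + question + " " + hint + "\n"
        s += "A: " + sgl.gen("answer", max_tokens=16)

    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    # Both forms now launch the same batch; tuples are matched to
    # (question, hint) by position, and hint may be omitted since it has a default.
    states = qa.run_batch([{"question": "What is 2 + 2?"}, {"question": "Is 7 prime?"}])
    states = qa.run_batch([("What is 2 + 2?",), ("Is 7 prime?", "Answer yes or no.")])
    # A tuple that does not fit the signature raises:
    # "Given arguments mismatch the SGL function signature".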
sglang/launch_server.py CHANGED
@@ -1,6 +1,9 @@
+ """Launch the inference server."""
+
  import argparse

- from sglang.srt.server import ServerArgs, launch_server
+ from sglang.srt.server import launch_server
+ from sglang.srt.server_args import ServerArgs

  if __name__ == "__main__":
      parser = argparse.ArgumentParser()
sglang/launch_server_llavavid.py CHANGED
@@ -1,10 +1,11 @@
+ """Launch the inference server for Llava-video model."""
+
  import argparse
  import multiprocessing as mp

  from sglang.srt.server import ServerArgs, launch_server

  if __name__ == "__main__":
-
      model_overide_args = {}

      model_overide_args["mm_spatial_pool_stride"] = 2
sglang/srt/constrained/__init__.py CHANGED
@@ -1,13 +1,19 @@
  import json
  from typing import Dict, Optional, Union

- from outlines.caching import cache as disk_cache
- from outlines.caching import disable_cache
- from outlines.fsm.fsm import RegexFSM
- from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
- from outlines.models.transformers import TransformerTokenizer
  from pydantic import BaseModel

+ try:
+     from outlines.caching import cache as disk_cache
+     from outlines.fsm.guide import RegexGuide
+     from outlines.caching import disable_cache
+     from outlines.fsm.guide import RegexGuide
+     from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+     from outlines.models.transformers import TransformerTokenizer
+ except ImportError as e:
+     print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
+     raise
+
  try:
      from outlines.fsm.json_schema import build_regex_from_object
  except ImportError:
@@ -28,11 +34,12 @@ except ImportError:


  __all__ = [
-     "RegexFSM",
+     "RegexGuide",
      "FSMInfo",
      "make_deterministic_fsm",
      "build_regex_from_object",
      "TransformerTokenizer",
      "disk_cache",
      "disable_cache",
+     "make_byte_level_fsm",
  ]
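For downstream code, the practical effect is the RegexFSM to RegexGuide rename plus the new make_byte_level_fsm export, all re-exported through this module. A minimal sketch of building a guide with a recent outlines (the model name is illustrative, and this assumes outlines>=0.0.44 where TransformerTokenizer wraps a Hugging Face tokenizer object):

    from transformers import AutoTokenizer

    from sglang.srt.constrained import RegexGuide, TransformerTokenizer

    # Wrap a HF tokenizer for outlines, then compile a regex into a guide.
    # fsm_cache.py below constructs its cached values the same way.
    hf_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v0.4")
    guide = RegexGuide(r"(yes|no)", TransformerTokenizer(hf_tokenizer))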
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -1,4 +1,6 @@
- from sglang.srt.constrained import RegexFSM, TransformerTokenizer
+ """Cache for the compressed finite state machine."""
+
+ from sglang.srt.constrained import RegexGuide, TransformerTokenizer
  from sglang.srt.constrained.base_cache import BaseCache


@@ -6,7 +8,8 @@ class FSMCache(BaseCache):
      def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
          super().__init__(enable=enable)

-         if tokenizer_path.endswith(".json"):
+         if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
+             # Do not support TiktokenTokenizer or SentencePieceTokenizer
              return

          from importlib.metadata import version
@@ -25,4 +28,4 @@ class FSMCache(BaseCache):
          )

      def init_value(self, regex):
-         return RegexFSM(regex, self.outlines_tokenizer)
+         return RegexGuide(regex, self.outlines_tokenizer)