sglang-0.1.16-py3-none-any.whl → sglang-0.1.17-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65):
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import inspect
2
3
 
3
4
  import uvloop
4
5
  import zmq
@@ -7,7 +8,8 @@ import zmq.asyncio
7
8
  from sglang.srt.hf_transformers_utils import get_tokenizer
8
9
  from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
9
10
  from sglang.srt.server_args import PortArgs, ServerArgs
10
- from sglang.srt.utils import get_exception_traceback
11
+ from sglang.utils import get_exception_traceback, graceful_registry
12
+ from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
11
13
 
12
14
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
13
15
 
@@ -33,51 +35,47 @@ class DetokenizerManager:
33
35
 
34
36
  async def handle_loop(self):
35
37
  while True:
36
- recv_obj = await self.recv_from_router.recv_pyobj()
37
-
38
- if isinstance(recv_obj, BatchTokenIDOut):
39
- output_tokens = recv_obj.output_tokens
40
-
41
- # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
42
- output_strs = self.tokenizer.batch_decode(
43
- output_tokens,
44
- skip_special_tokens=recv_obj.skip_special_tokens[0],
45
- spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
46
- 0
47
- ],
48
- )
49
-
50
- # Trim stop str
51
- # TODO(lmzheng): handle the case where multiple stop strs are hit
52
- for i in range(len(output_strs)):
53
- if recv_obj.hit_stop_str[i] is not None:
54
- pos = output_strs[i].find(recv_obj.hit_stop_str[i])
55
- if pos != -1:
56
- output_strs[i] = output_strs[i][:pos]
57
-
58
- if len(output_tokens[i]) > 0:
59
- first_token = self.tokenizer.convert_ids_to_tokens(
60
- int(output_tokens[i][0])
61
- )
62
- if not isinstance(first_token, str):
63
- first_token = first_token.decode("utf-8", errors="ignore")
64
- if first_token.startswith("▁"):
65
- output_strs[i] = " " + output_strs[i]
66
-
67
- output_strs[i] = (
68
- recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
69
- )
70
-
71
- self.send_to_tokenizer.send_pyobj(
72
- BatchStrOut(
73
- recv_obj.rids,
74
- output_strs,
75
- recv_obj.meta_info,
76
- recv_obj.finished,
38
+ recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
39
+ assert isinstance(recv_obj, BatchTokenIDOut)
40
+
41
+ output_tokens = recv_obj.output_tokens
42
+
43
+ # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
44
+ output_strs = self.tokenizer.batch_decode(
45
+ output_tokens,
46
+ skip_special_tokens=recv_obj.skip_special_tokens[0],
47
+ spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
48
+ 0
49
+ ],
50
+ )
51
+
52
+ # Trim stop str
53
+ # TODO(lmzheng): handle the case where multiple stop strs are hit
54
+ for i in range(len(output_strs)):
55
+ if len(output_tokens[i]) > 0:
56
+ first_token = self.tokenizer.convert_ids_to_tokens(
57
+ int(output_tokens[i][0])
77
58
  )
59
+ if not isinstance(first_token, str):
60
+ first_token = first_token.decode("utf-8", errors="ignore")
61
+ if first_token.startswith("▁"):
62
+ output_strs[i] = " " + output_strs[i]
63
+
64
+ output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
65
+
66
+ if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
67
+ pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
68
+ if pos != -1:
69
+ output_strs[i] = output_strs[i][:pos]
70
+
71
+ self.send_to_tokenizer.send_pyobj(
72
+ BatchStrOut(
73
+ rids=recv_obj.rids,
74
+ output_str=output_strs,
75
+ meta_info=recv_obj.meta_info,
76
+ finished_reason=recv_obj.finished_reason,
78
77
  )
79
- else:
80
- raise ValueError(f"Invalid object: {recv_obj}")
78
+ )
81
79
 
82
80
 
83
81
  def start_detokenizer_process(
@@ -85,9 +83,11 @@ def start_detokenizer_process(
85
83
  port_args: PortArgs,
86
84
  pipe_writer,
87
85
  ):
86
+ graceful_registry(inspect.currentframe().f_code.co_name)
87
+
88
88
  try:
89
89
  manager = DetokenizerManager(server_args, port_args)
90
- except Exception as e:
90
+ except Exception:
91
91
  pipe_writer.send(get_exception_traceback())
92
92
  raise
93
93
  pipe_writer.send("init ok")
@@ -3,6 +3,7 @@ from dataclasses import dataclass
3
3
  from typing import Dict, List, Optional, Union
4
4
 
5
5
  from sglang.srt.sampling_params import SamplingParams
6
+ from sglang.srt.managers.controller.infer_batch import BaseFinishReason
6
7
 
7
8
 
8
9
  @dataclass
@@ -27,14 +28,13 @@ class GenerateReqInput:
27
28
  return_text_in_logprobs: bool = False
28
29
  # Whether to stream output
29
30
  stream: bool = False
30
- # TODO: make all parameters a Union[List[T], T] to allow for batched requests
31
31
 
32
32
  def post_init(self):
33
33
 
34
- if self.text is None:
35
- assert self.input_ids is not None, "Either text or input_ids should be provided"
36
- else:
37
- assert self.input_ids is None, "Either text or input_ids should be provided"
34
+ if (self.text is None and self.input_ids is None) or (
35
+ self.text is not None and self.input_ids is not None
36
+ ):
37
+ raise ValueError("Either text or input_ids should be provided.")
38
38
 
39
39
  if self.text is not None:
40
40
  is_single = isinstance(self.text, str)
@@ -69,7 +69,8 @@ class GenerateReqInput:
69
69
  if self.rid is None:
70
70
  self.rid = [uuid.uuid4().hex for _ in range(num)]
71
71
  else:
72
- assert isinstance(self.rid, list)
72
+ if not isinstance(self.rid, list):
73
+ raise ValueError("The rid should be a list.")
73
74
 
74
75
  if self.return_logprob is None:
75
76
  self.return_logprob = [False] * num
@@ -105,21 +106,19 @@ class TokenizedGenerateReqInput:
105
106
  @dataclass
106
107
  class BatchTokenIDOut:
107
108
  rids: List[str]
109
+ prev_output_strs: List[str]
108
110
  output_tokens: List[List[int]]
109
- output_and_jump_forward_strs: List[str]
110
- hit_stop_str: List[Optional[str]]
111
111
  skip_special_tokens: List[bool]
112
112
  spaces_between_special_tokens: List[bool]
113
113
  meta_info: List[Dict]
114
- finished: List[bool]
115
-
114
+ finished_reason: List[BaseFinishReason]
116
115
 
117
116
  @dataclass
118
117
  class BatchStrOut:
119
118
  rids: List[str]
120
119
  output_str: List[str]
121
120
  meta_info: List[Dict]
122
- finished: List[bool]
121
+ finished_reason: List[BaseFinishReason]
123
122
 
124
123
 
125
124
  @dataclass
@@ -127,6 +126,11 @@ class FlushCacheReq:
127
126
  pass
128
127
 
129
128
 
129
+ @dataclass
130
+ class AbortReq:
131
+ rid: str
132
+
133
+
130
134
  @dataclass
131
135
  class DetokenizeReqInput:
132
136
  input_ids: List[int]
@@ -19,6 +19,7 @@ class FinishReason(IntEnum):
19
19
  EOS_TOKEN = auto()
20
20
  LENGTH = auto()
21
21
  STOP_STR = auto()
22
+ ABORT = auto()
22
23
 
23
24
  @staticmethod
24
25
  def to_str(reason):
@@ -28,20 +29,22 @@ class FinishReason(IntEnum):
28
29
  return "length"
29
30
  elif reason == FinishReason.STOP_STR:
30
31
  return "stop"
32
+ elif reason == FinishReason.ABORT:
33
+ return "abort"
31
34
  else:
32
35
  return None
33
36
 
34
37
 
35
38
  class Req:
36
- def __init__(self, rid, input_text, input_ids):
39
+ def __init__(self, rid, origin_input_text, origin_input_ids):
37
40
  self.rid = rid
38
- self.input_text = input_text
39
- self.input_ids = input_ids
41
+ self.origin_input_text = origin_input_text
42
+ self.origin_input_ids = origin_input_ids
43
+ self.origin_input_ids_unpadded = origin_input_ids # before image padding
44
+ self.prev_output_str = ""
45
+ self.prev_output_ids = []
40
46
  self.output_ids = []
41
-
42
- # Since jump forward may retokenize the prompt with partial outputs,
43
- # we maintain the original prompt length to report the correct usage.
44
- self.prompt_tokens = len(input_ids)
47
+ self.input_ids = None # input_ids = origin_input_ids + prev_output_ids
45
48
 
46
49
  # The number of decoded tokens for token usage report. Note that
47
50
  # this does not include the jump forward tokens.
@@ -63,6 +66,7 @@ class Req:
63
66
  self.finish_reason = None
64
67
  self.hit_stop_str = None
65
68
 
69
+ # Prefix info
66
70
  self.extend_input_len = 0
67
71
  self.prefix_indices = []
68
72
  self.last_node = None
@@ -73,70 +77,36 @@ class Req:
73
77
  self.top_logprobs_num = 0
74
78
  self.normalized_prompt_logprob = None
75
79
  self.prefill_token_logprobs = None
76
- self.decode_token_logprobs = None
77
80
  self.prefill_top_logprobs = None
78
- self.decode_top_logprobs = None
81
+ self.decode_token_logprobs = []
82
+ self.decode_top_logprobs = []
83
+ # The tokens is prefilled but need to be considered as decode tokens
84
+ # and should be updated for the decode logprobs
85
+ self.last_update_decode_tokens = 0
79
86
 
80
87
  # Constrained decoding
81
88
  self.regex_fsm = None
82
89
  self.regex_fsm_state = 0
83
90
  self.jump_forward_map = None
84
- self.output_and_jump_forward_str = ""
85
-
86
- def max_new_tokens(self):
87
- return self.sampling_params.max_new_tokens
88
91
 
89
- def jump_forward_and_retokenize(self, jump_forward_str, next_state):
90
- old_output_str = self.tokenizer.decode(self.output_ids)
91
- # FIXME: This logic does not really solve the problem of determining whether
92
- # there should be a leading space.
93
- first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
92
+ def partial_decode(self, ids):
93
+ first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
94
94
  first_token = (
95
95
  first_token.decode() if isinstance(first_token, bytes) else first_token
96
96
  )
97
- if first_token.startswith("▁"):
98
- old_output_str = " " + old_output_str
99
- if self.input_text is None:
100
- # TODO(lmzheng): This can be wrong. Check with Liangsheng.
101
- self.input_text = self.tokenizer.decode(self.input_ids)
102
- new_input_string = (
103
- self.input_text
104
- + self.output_and_jump_forward_str
105
- + old_output_str
106
- + jump_forward_str
107
- )
108
- new_input_ids = self.tokenizer.encode(new_input_string)
109
- if self.pixel_values is not None:
110
- # NOTE: This is a hack because the old input_ids contains the image padding
111
- jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
112
- else:
113
- jump_forward_tokens_len = (
114
- len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
115
- )
116
-
117
- # print("=" * 100)
118
- # print(f"Catch jump forward:\n{jump_forward_str}")
119
- # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
120
- # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
121
-
122
- self.input_ids = new_input_ids
123
- self.output_ids = []
124
- self.sampling_params.max_new_tokens = max(
125
- self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
126
- )
127
- self.regex_fsm_state = next_state
128
- self.output_and_jump_forward_str = (
129
- self.output_and_jump_forward_str + old_output_str + jump_forward_str
130
- )
97
+ return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
131
98
 
132
- # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
133
- # print("*" * 100)
99
+ def max_new_tokens(self):
100
+ return self.sampling_params.max_new_tokens
134
101
 
135
102
  def check_finished(self):
136
103
  if self.finished:
137
104
  return
138
105
 
139
- if len(self.output_ids) >= self.sampling_params.max_new_tokens:
106
+ if (
107
+ len(self.prev_output_ids) + len(self.output_ids)
108
+ >= self.sampling_params.max_new_tokens
109
+ ):
140
110
  self.finished = True
141
111
  self.finish_reason = FinishReason.LENGTH
142
112
  return
@@ -155,14 +125,66 @@ class Req:
155
125
  )
156
126
 
157
127
  for stop_str in self.sampling_params.stop_strs:
158
- if stop_str in tail_str:
128
+ # FIXME: (minor) try incremental match in prev_output_str
129
+ if stop_str in tail_str or stop_str in self.prev_output_str:
159
130
  self.finished = True
160
131
  self.finish_reason = FinishReason.STOP_STR
161
132
  self.hit_stop_str = stop_str
162
133
  return
163
134
 
135
+ def jump_forward_and_retokenize(self, jump_forward_str, next_state):
136
+ # FIXME: This logic does not really solve the problem of determining whether
137
+ # there should be a leading space.
138
+ cur_output_str = self.partial_decode(self.output_ids)
139
+
140
+ # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
141
+ if self.origin_input_text is None:
142
+ # Recovering text can only use unpadded ids
143
+ self.origin_input_text = self.tokenizer.decode(
144
+ self.origin_input_ids_unpadded
145
+ )
146
+
147
+ all_text = (
148
+ self.origin_input_text
149
+ + self.prev_output_str
150
+ + cur_output_str
151
+ + jump_forward_str
152
+ )
153
+ all_ids = self.tokenizer.encode(all_text)
154
+ prompt_tokens = len(self.origin_input_ids_unpadded)
155
+ self.origin_input_ids = all_ids[:prompt_tokens]
156
+ self.origin_input_ids_unpadded = self.origin_input_ids
157
+ # NOTE: the output ids may not strictly correspond to the output text
158
+ old_prev_output_ids = self.prev_output_ids
159
+ self.prev_output_ids = all_ids[prompt_tokens:]
160
+ self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
161
+ self.output_ids = []
162
+
163
+ self.regex_fsm_state = next_state
164
+
165
+ if self.return_logprob:
166
+ # For fast-forward part's logprobs
167
+ k = 0
168
+ for i, old_id in enumerate(old_prev_output_ids):
169
+ if old_id == self.prev_output_ids[i]:
170
+ k = k + 1
171
+ else:
172
+ break
173
+ self.decode_token_logprobs = self.decode_token_logprobs[:k]
174
+ self.decode_top_logprobs = self.decode_top_logprobs[:k]
175
+ self.logprob_start_len = prompt_tokens + k
176
+ self.last_update_decode_tokens = len(self.prev_output_ids) - k
177
+
178
+ # print("=" * 100)
179
+ # print(f"Catch jump forward:\n{jump_forward_str}")
180
+ # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
181
+ # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
182
+
183
+ # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
184
+ # print("*" * 100)
185
+
164
186
  def __repr__(self):
165
- return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
187
+ return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
166
188
 
167
189
 
168
190
  @dataclass
@@ -333,6 +355,7 @@ class Batch:
333
355
 
334
356
  def retract_decode(self):
335
357
  sorted_indices = [i for i in range(len(self.reqs))]
358
+ # TODO(lsyin): improve the priority of retraction
336
359
  sorted_indices.sort(
337
360
  key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
338
361
  reverse=True,
@@ -353,18 +376,27 @@ class Batch:
353
376
  ][last_uncached_pos : seq_lens_cpu[idx]]
354
377
  self.token_to_kv_pool.dec_refs(token_indices)
355
378
 
379
+ # release the last node
356
380
  self.tree_cache.dec_lock_ref(req.last_node)
381
+
382
+ cur_output_str = req.partial_decode(req.output_ids)
383
+ req.prev_output_str = req.prev_output_str + cur_output_str
384
+ req.prev_output_ids.extend(req.output_ids)
385
+
357
386
  req.prefix_indices = None
358
387
  req.last_node = None
359
388
  req.extend_input_len = 0
360
389
  req.output_ids = []
361
- req.regex_fsm_state = 0
390
+
391
+ # For incremental logprobs
392
+ req.last_update_decode_tokens = 0
393
+ req.logprob_start_len = 10**9
362
394
 
363
395
  self.filter_batch(sorted_indices)
364
396
 
365
397
  return retracted_reqs
366
398
 
367
- def check_for_jump_forward(self):
399
+ def check_for_jump_forward(self, model_runner):
368
400
  jump_forward_reqs = []
369
401
  filter_indices = [i for i in range(len(self.reqs))]
370
402
 
@@ -394,6 +426,18 @@ class Batch:
394
426
  # jump-forward
395
427
  req.jump_forward_and_retokenize(jump_forward_str, next_state)
396
428
 
429
+ # re-applying image padding
430
+ if req.pixel_values is not None:
431
+ (
432
+ req.origin_input_ids,
433
+ req.image_offset,
434
+ ) = model_runner.model.pad_input_ids(
435
+ req.origin_input_ids_unpadded,
436
+ req.pad_value,
437
+ req.pixel_values.shape,
438
+ req.image_size,
439
+ )
440
+
397
441
  jump_forward_reqs.append(req)
398
442
  filter_indices.remove(i)
399
443
 
@@ -8,7 +8,7 @@ import zmq.asyncio
8
8
  from sglang.global_config import global_config
9
9
  from sglang.srt.managers.router.model_rpc import ModelRpcClient
10
10
  from sglang.srt.server_args import PortArgs, ServerArgs
11
- from sglang.srt.utils import get_exception_traceback
11
+ from sglang.utils import get_exception_traceback
12
12
 
13
13
  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
14
14