sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/infer_batch.py

@@ -1,14 +1,20 @@
 """Meta data for requests and batches"""
+
+import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import List, Union
 
 import numpy as np
 import torch
 
+from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
+INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
 
 class ForwardMode(IntEnum):
     PREFILL = auto()
@@ -25,7 +31,7 @@ class BaseFinishReason:
 
 
 class FINISH_MATCHED_TOKEN(BaseFinishReason):
-    def __init__(self, matched: int | List[int]):
+    def __init__(self, matched: Union[int, List[int]]):
         super().__init__()
         self.matched = matched
 
@@ -63,12 +69,15 @@ class Req:
     def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
         self.origin_input_text = origin_input_text
+        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
-        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
-        self.prev_output_str = ""
-        self.prev_output_ids = []
-        self.output_ids = []
-        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
+        self.output_ids = []  # Each decode stage's output ids
+        self.input_ids = None  # input_ids = origin_input_ids + output_ids
+
+        # For incremental decode
+        self.decoded_text = ""
+        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
+        self.read_offset = None
 
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
@@ -108,20 +117,54 @@ class Req:
         self.last_update_decode_tokens = 0
 
         # Constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
+        self.regex_fsm: RegexGuide = None
+        self.regex_fsm_state: int = 0
+        self.jump_forward_map: JumpForwardMap = None
 
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
 
-    def partial_decode(self, ids):
-        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
+    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
+    def init_detokenize_incrementally(self):
+        first_iter = self.surr_offset is None or self.read_offset is None
+
+        if first_iter:
+            self.read_offset = len(self.origin_input_ids_unpadded)
+            self.surr_offset = max(
+                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
+            )
+
+        all_ids = self.origin_input_ids_unpadded + self.output_ids
+        surr_ids = all_ids[self.surr_offset : self.read_offset]
+        read_ids = all_ids[self.surr_offset :]
+
+        return surr_ids, read_ids, len(all_ids)
+
+    def detokenize_incrementally(self, inplace: bool = True):
+        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+
+        surr_text = self.tokenizer.decode(
+            surr_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+        )
+        new_text = self.tokenizer.decode(
+            read_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
         )
-        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
+
+        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
+            new_text = new_text[len(surr_text) :]
+            if inplace:
+                self.decoded_text += new_text
+                self.surr_offset = self.read_offset
+                self.read_offset = num_all_tokens
+
+            return True, new_text
+
+        return False, ""
 
     def max_new_tokens(self):
        return self.sampling_params.max_new_tokens
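The new incremental detokenization follows the windowed approach referenced above from vLLM: keep a small window of "surrounding" token ids, decode the window together with the new tokens, and only emit the suffix once it no longer ends in the UTF-8 replacement character (i.e. no multi-byte character is still incomplete). A minimal sketch of the same idea, purely illustrative and not the sglang API, assuming a HuggingFace-style tokenizer object:

    # Sketch of windowed incremental detokenization (names are illustrative).
    def incremental_decode(tokenizer, all_ids, surr_offset, read_offset):
        """Return (ok, new_text, surr_offset, read_offset) given all tokens seen so far."""
        surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
        full_text = tokenizer.decode(all_ids[surr_offset:])
        # Emit only when the decoded suffix is non-empty and is not cut in the
        # middle of a multi-byte character (no trailing U+FFFD).
        if len(full_text) > len(surr_text) and not full_text.endswith("\ufffd"):
            return True, full_text[len(surr_text):], read_offset, len(all_ids)
        return False, "", surr_offset, read_offset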
@@ -130,18 +173,17 @@ class Req:
         if self.finished():
             return
 
-        if (
-            len(self.prev_output_ids) + len(self.output_ids)
-            >= self.sampling_params.max_new_tokens
-        ):
-            self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids))
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
             return
 
         if (
             self.output_ids[-1] == self.tokenizer.eos_token_id
             and not self.sampling_params.ignore_eos
         ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id)
+            self.finished_reason = FINISH_MATCHED_TOKEN(
+                matched=self.tokenizer.eos_token_id
+            )
             return
 
         if len(self.sampling_params.stop_strs) > 0:
@@ -150,61 +192,59 @@ class Req:
             )
 
             for stop_str in self.sampling_params.stop_strs:
-                # FIXME: (minor) try incremental match in prev_output_str
-                if stop_str in tail_str or stop_str in self.prev_output_str:
+                if stop_str in tail_str or stop_str in self.decoded_text:
                     self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                     return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        cur_output_str = self.partial_decode(self.output_ids)
-
-        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
                 self.origin_input_ids_unpadded
             )
 
-        all_text = (
-            self.origin_input_text
-            + self.prev_output_str
-            + cur_output_str
-            + jump_forward_str
-        )
+        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
         prompt_tokens = len(self.origin_input_ids_unpadded)
-        self.origin_input_ids = all_ids[:prompt_tokens]
-        self.origin_input_ids_unpadded = self.origin_input_ids
-        # NOTE: the output ids may not strictly correspond to the output text
-        old_prev_output_ids = self.prev_output_ids
-        self.prev_output_ids = all_ids[prompt_tokens:]
-        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
-        self.output_ids = []
+
+        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
+            # TODO(lsyin): fix token fusion
+            warnings.warn(
+                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
+            )
+            return False
+
+        old_output_ids = self.output_ids
+        self.output_ids = all_ids[prompt_tokens:]
+        self.decoded_text = self.decoded_text + jump_forward_str
+        self.surr_offset = prompt_tokens
+        self.read_offset = len(all_ids)
+
+        # NOTE: A trick to reduce the surrouding tokens decoding overhead
+        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
+            surr_text_ = self.tokenizer.decode(
+                all_ids[self.read_offset - i : self.read_offset]
+            )
+            if not surr_text_.endswith("�"):
+                self.surr_offset = self.read_offset - i
+                break
 
         self.regex_fsm_state = next_state
 
         if self.return_logprob:
             # For fast-forward part's logprobs
             k = 0
-            for i, old_id in enumerate(old_prev_output_ids):
-                if old_id == self.prev_output_ids[i]:
+            for i, old_id in enumerate(old_output_ids):
+                if old_id == self.output_ids[i]:
                     k = k + 1
                 else:
                     break
             self.decode_token_logprobs = self.decode_token_logprobs[:k]
             self.decode_top_logprobs = self.decode_top_logprobs[:k]
             self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.prev_output_ids) - k
-
-            # print("=" * 100)
-            # print(f"Catch jump forward:\n{jump_forward_str}")
-            # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-            # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+            self.last_update_decode_tokens = len(self.output_ids) - k
 
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-        # print("*" * 100)
+        return True
 
     def __repr__(self):
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
@@ -263,6 +303,10 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
+    # whether batch has at least 1 streaming request
+    def has_stream(self) -> bool:
+        return any(r.stream for r in self.reqs)
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
         bs = len(self.reqs)
@@ -380,7 +424,10 @@ class Batch:
         sorted_indices = [i for i in range(len(self.reqs))]
         # TODO(lsyin): improve the priority of retraction
         sorted_indices.sort(
-            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
+            key=lambda i: (
+                len(self.reqs[i].output_ids),
+                -len(self.reqs[i].origin_input_ids),
+            ),
             reverse=True,
         )
 
@@ -402,14 +449,9 @@ class Batch:
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
 
-            cur_output_str = req.partial_decode(req.output_ids)
-            req.prev_output_str = req.prev_output_str + cur_output_str
-            req.prev_output_ids.extend(req.output_ids)
-
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
-            req.output_ids = []
 
             # For incremental logprobs
             req.last_update_decode_tokens = 0
@@ -427,18 +469,54 @@ class Batch:
 
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
-                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                if res is not None:
-                    jump_forward_str, next_state = res
-                    if len(jump_forward_str) <= 1:
+                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
+                    req.regex_fsm_state
+                )
+                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
+                    suffix_bytes = []
+                    continuation_range = range(0x80, 0xC0)
+                    cur_state = req.regex_fsm_state
+                    while (
+                        len(jump_forward_bytes)
+                        and jump_forward_bytes[0][0] in continuation_range
+                    ):
+                        # continuation bytes
+                        byte_edge = jump_forward_bytes.pop(0)
+                        suffix_bytes.append(byte_edge[0])
+                        cur_state = byte_edge[1]
+
+                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+
+                    # Current ids, for cache and revert
+                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
+                    cur_output_ids = req.output_ids
+
+                    req.output_ids.extend(suffix_ids)
+                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                    if not decode_res:
+                        req.output_ids = cur_output_ids
                         continue
 
-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
+                    (
+                        jump_forward_str,
+                        next_state,
+                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
+                    jump_forward_str = new_text + jump_forward_str
+                    if not req.jump_forward_and_retokenize(
+                        jump_forward_str, next_state
+                    ):
+                        req.output_ids = cur_output_ids
+                        continue
 
                     # insert the old request into tree_cache
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     self.tree_cache.cache_req(
-                        token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                        token_ids=cur_all_ids,
                         last_uncached_pos=len(req.prefix_indices),
                         req_pool_idx=req_pool_indices_cpu[i],
                     )
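The jump-forward path is now byte-aware: if the next forced bytes start with UTF-8 continuation bytes (0x80–0xBF), they belong to a character whose first byte was already emitted, so they are appended first as byte-fallback tokens (e.g. Llama-style "<0xAB>") before jumping forward at the symbol level. A small sketch of that classification, assuming a byte-fallback vocabulary and using plain byte lists for illustration:

    # Sketch: peel off UTF-8 continuation bytes and map them to byte-fallback tokens.
    def split_continuation_bytes(byte_seq):
        continuation = range(0x80, 0xC0)  # 10xxxxxx continuation bytes
        prefix = []
        while byte_seq and byte_seq[0] in continuation:
            prefix.append(byte_seq.pop(0))
        return prefix, byte_seq

    def to_byte_tokens(byte_values):
        # Byte-fallback tokenizers expose raw bytes as "<0xAB>"-style tokens.
        return [f"<0x{b:02X}>" for b in byte_values]

    prefix, rest = split_continuation_bytes([0x9C, 0xE4, 0xB8, 0xAD])
    print(to_byte_tokens(prefix))  # ['<0x9C>'] completes the pending character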
@@ -446,9 +524,6 @@ class Batch:
                     # unlock the last node
                     self.tree_cache.dec_lock_ref(req.last_node)
 
-                    # jump-forward
-                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
                     # re-applying image padding
                     if req.pixel_values is not None:
                         (
@@ -582,7 +657,7 @@ class Batch:
             if req.regex_fsm is not None:
                 allowed_mask.zero_()
                 allowed_mask[
-                    req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
+                    req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                 ] = 1
                 logits[i].masked_fill_(~allowed_mask, float("-inf"))
 
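The constrained-decoding hooks move to the outlines RegexGuide interface (get_next_instruction(state).tokens here, get_next_state in the next hunk), but the masking pattern itself is unchanged: zero a boolean mask, mark the allowed token ids, and fill every other logit with -inf. A self-contained sketch of that pattern (the token ids are made up):

    # Sketch: restrict sampling to an allowed-token set by masking logits.
    import torch

    def mask_logits(logits: torch.Tensor, allowed_token_ids) -> torch.Tensor:
        allowed_mask = torch.zeros_like(logits, dtype=torch.bool)
        allowed_mask[list(allowed_token_ids)] = True
        return logits.masked_fill(~allowed_mask, float("-inf"))

    constrained = mask_logits(torch.randn(32000), allowed_token_ids=[11, 198, 2599])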
@@ -601,7 +676,7 @@ class Batch:
         batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
         for i, req in enumerate(self.reqs):
             if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.next_state(
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )
 
sglang/srt/managers/controller/manager_multi.py

@@ -13,15 +13,15 @@ import zmq
 import zmq.asyncio
 
 from sglang.global_config import global_config
+from sglang.srt.managers.controller.dp_worker import (
+    DataParallelWorkerThread,
+    start_data_parallel_worker,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import get_exception_traceback
 
@@ -136,7 +136,7 @@ class Controller:
             self.recv_reqs = []
             if next_step_input:
                 await self.dispatching(next_step_input)
-            #else:
+            # else:
             #     logger.error("There is no live worker.")
 
             await asyncio.sleep(global_config.wait_for_new_request_delay)
sglang/srt/managers/controller/manager_single.py

@@ -1,7 +1,8 @@
 """A controller that manages a group of tensor parallel workers."""
+
 import asyncio
 import logging
-import time
+from concurrent.futures import ThreadPoolExecutor
 
 import uvloop
 import zmq
@@ -49,7 +50,9 @@ class ControllerSingle:
             # async sleep for receiving the subsequent request and avoiding cache miss
             slept = False
             if len(out_pyobjs) != 0:
-                has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+                has_finished = any(
+                    [obj.finished_reason is not None for obj in out_pyobjs]
+                )
                 if has_finished:
                     if self.request_dependency_delay > 0:
                         slept = True
@@ -73,8 +76,9 @@ def start_controller_process(
     )
 
     try:
+        tp_size_local = server_args.tp_size // server_args.nnodes
         model_client = ModelTpClient(
-            list(range(server_args.tp_size)),
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
             server_args,
             port_args.model_port_args[0],
             model_overide_args,
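With multi-node tensor parallelism, tp_size is now split across nnodes and each node enumerates its local GPU ids. A worked example of what the new list comprehension produces (the values are only an example):

    # tp_size=8 across nnodes=2 -> each node uses local GPU ids 0..3.
    tp_size, nnodes = 8, 2
    tp_size_local = tp_size // nnodes
    gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
    print(gpu_ids)  # [0, 1, 2, 3, 0, 1, 2, 3]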
@@ -87,6 +91,7 @@ def start_controller_process(
     pipe_writer.send("init ok")
 
     loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
     try:
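Setting a larger default executor means every loop.run_in_executor(None, ...) call in the controller draws from a 256-thread pool rather than asyncio's small default, so many blocking worker calls can be in flight at once. A minimal standalone sketch of the mechanism (not the sglang code path):

    # Sketch: run_in_executor(None, ...) uses the loop's default executor.
    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    def blocking_call():
        return 42

    async def main():
        loop = asyncio.get_running_loop()
        loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
        print(await loop.run_in_executor(None, blocking_call))

    asyncio.run(main())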
@@ -94,4 +99,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        kill_parent_process()
+        kill_parent_process()