sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +4 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/bench_latency.py +299 -0
  6. sglang/global_config.py +4 -1
  7. sglang/lang/compiler.py +2 -2
  8. sglang/lang/interpreter.py +1 -1
  9. sglang/lang/ir.py +15 -5
  10. sglang/launch_server.py +4 -1
  11. sglang/launch_server_llavavid.py +2 -1
  12. sglang/srt/constrained/__init__.py +13 -6
  13. sglang/srt/constrained/fsm_cache.py +6 -3
  14. sglang/srt/constrained/jump_forward.py +113 -25
  15. sglang/srt/conversation.py +2 -0
  16. sglang/srt/flush_cache.py +2 -0
  17. sglang/srt/hf_transformers_utils.py +64 -9
  18. sglang/srt/layers/fused_moe.py +186 -89
  19. sglang/srt/layers/logits_processor.py +53 -25
  20. sglang/srt/layers/radix_attention.py +34 -7
  21. sglang/srt/managers/controller/dp_worker.py +6 -3
  22. sglang/srt/managers/controller/infer_batch.py +142 -67
  23. sglang/srt/managers/controller/manager_multi.py +5 -5
  24. sglang/srt/managers/controller/manager_single.py +8 -3
  25. sglang/srt/managers/controller/model_runner.py +154 -54
  26. sglang/srt/managers/controller/radix_cache.py +4 -0
  27. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  28. sglang/srt/managers/controller/tp_worker.py +140 -135
  29. sglang/srt/managers/detokenizer_manager.py +15 -19
  30. sglang/srt/managers/io_struct.py +10 -4
  31. sglang/srt/managers/tokenizer_manager.py +14 -13
  32. sglang/srt/model_config.py +83 -4
  33. sglang/srt/models/chatglm.py +399 -0
  34. sglang/srt/models/commandr.py +2 -2
  35. sglang/srt/models/dbrx.py +1 -1
  36. sglang/srt/models/gemma.py +5 -1
  37. sglang/srt/models/grok.py +204 -137
  38. sglang/srt/models/llama2.py +11 -4
  39. sglang/srt/models/llama_classification.py +104 -0
  40. sglang/srt/models/llava.py +11 -8
  41. sglang/srt/models/llavavid.py +1 -1
  42. sglang/srt/models/mixtral.py +164 -115
  43. sglang/srt/models/mixtral_quant.py +0 -1
  44. sglang/srt/models/qwen.py +1 -1
  45. sglang/srt/models/qwen2.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/models/yivl.py +2 -2
  48. sglang/srt/openai_api_adapter.py +33 -23
  49. sglang/srt/openai_protocol.py +1 -1
  50. sglang/srt/server.py +60 -19
  51. sglang/srt/server_args.py +79 -44
  52. sglang/srt/utils.py +146 -37
  53. sglang/test/test_programs.py +28 -10
  54. sglang/utils.py +4 -3
  55. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
  56. sglang-0.1.18.dist-info/RECORD +78 -0
  57. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  58. sglang/srt/managers/router/infer_batch.py +0 -596
  59. sglang/srt/managers/router/manager.py +0 -82
  60. sglang/srt/managers/router/model_rpc.py +0 -818
  61. sglang/srt/managers/router/model_runner.py +0 -445
  62. sglang/srt/managers/router/radix_cache.py +0 -267
  63. sglang/srt/managers/router/scheduler.py +0 -59
  64. sglang-0.1.17.dist-info/RECORD +0 -81
  65. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  66. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -1,3 +1,8 @@
+"""Logits processing."""
+
+import dataclasses
+from typing import List
+
 import torch
 from torch import nn
 from vllm.distributed import (
@@ -8,6 +13,24 @@ from vllm.distributed import (
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+
+    # The normlaized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+
+    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -37,6 +60,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
     def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        # TODO: vectorize the code below
         if input_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
@@ -49,7 +73,6 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            # NOTE: the GPU-CPU overhead can be reduced
            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
@@ -69,18 +92,15 @@ class LogitsProcessor(nn.Module):
         return prefill_top_logprobs, decode_top_logprobs
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last index for next token prediction, except for DECODE mode.
-        last_index = None
-        if input_metadata.forward_mode != ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
             last_index = (
                 torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
@@ -90,8 +110,14 @@ class LogitsProcessor(nn.Module):
 
         # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
-            hidden_states = None
-            return last_logits, (None, None, None, None, None)
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if input_metadata.forward_mode == ForwardMode.DECODE:
@@ -106,6 +132,7 @@ class LogitsProcessor(nn.Module):
             del all_logits
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
+            # Get the logprob of top-k tokens
            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
@@ -115,16 +142,15 @@ class LogitsProcessor(nn.Module):
                 prefill_top_logprobs = decode_top_logprobs = None
 
             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                return last_logits, (
-                    None,
-                    None,
-                    None,
-                    decode_top_logprobs,
-                    last_logprobs,
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
             else:
-                # Compute the logprobs for the last token of each request.
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -137,12 +163,14 @@ class LogitsProcessor(nn.Module):
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                     prefill_token_logprobs, input_metadata
                 )
-                return last_logits, (
-                    prefill_token_logprobs,
-                    normalized_prompt_logprobs,
-                    prefill_top_logprobs,
-                    decode_top_logprobs,
-                    last_logprobs,
+
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
 
 
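
The old 5-tuple return value is replaced by the LogitProcessorOutput dataclass. Below is a minimal, self-contained sketch (not part of the diff) of how downstream code can consume the new return type; the dummy shapes are placeholders and the Optional annotations are added only so the standalone example reads cleanly:

    import dataclasses
    from typing import List, Optional

    import torch


    @dataclasses.dataclass
    class LogitProcessorOutput:
        # Mirrors the dataclass added above.
        next_token_logits: torch.Tensor
        next_token_logprobs: Optional[torch.Tensor]
        normalized_prompt_logprobs: Optional[torch.Tensor]
        prefill_token_logprobs: Optional[torch.Tensor]
        prefill_top_logprobs: Optional[List]
        decode_top_logprobs: Optional[List]


    # Placeholder shapes: 2 sequences, vocab size 32000.
    out = LogitProcessorOutput(
        next_token_logits=torch.randn(2, 32000),
        next_token_logprobs=None,
        normalized_prompt_logprobs=None,
        prefill_token_logprobs=None,
        prefill_top_logprobs=None,
        decode_top_logprobs=None,
    )

    # 0.1.17 callers unpacked a positional 5-tuple; 0.1.18 callers read fields by name,
    # so code that only needs the next-token logits can ignore the logprob fields.
    next_token_ids = torch.argmax(out.next_token_logits, dim=-1)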

sglang/srt/layers/radix_attention.py

@@ -1,7 +1,10 @@
-import torch
+"""Radix attention."""
+
 import numpy as np
+import torch
 from torch import nn
 
+from sglang.global_config import global_config
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
@@ -9,27 +12,32 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
+    def __init__(
+        self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
+        layer_id: int, logit_cap: int = -1
+    ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
         self.layer_id = layer_id
-        self.logit_cap = logit_cap
 
         assert np.allclose(scaling, 1.0 / (head_dim**0.5))
 
         from sglang.srt.managers.controller.model_runner import global_server_args_dict
 
-        if global_server_args_dict.get("enable_flashinfer", False):
+        if not global_server_args_dict.get("disable_flashinfer", False):
             self.prefill_forward = self.prefill_forward_flashinfer
             self.extend_forward = self.prefill_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
+            # flashinfer now accepts float logit_cap argument
+            self.logit_cap = logit_cap if logit_cap > 0 else 0
         else:
             self.prefill_forward = self.prefill_forward_triton
             self.extend_forward = self.extend_forward_triton
             self.decode_forward = self.decode_forward_triton
+            self.logit_cap = logit_cap
 
     def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
         o = torch.empty_like(q)
@@ -96,19 +104,38 @@ class RadixAttention(nn.Module):
     def prefill_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.prefill_wrapper.forward(
+        o1, s1 = input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+            v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+            logits_soft_cap=self.logit_cap,
         )
 
+        if input_metadata.no_prefix:
+            o = o1
+        else:
+            o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
+                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                causal=False,
+                logits_soft_cap=self.logit_cap,
+            )
+
+            from flashinfer.cascade import merge_state
+            o, _ = merge_state(o1, s1, o2, s2)
+
+        if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
+            torch.cuda.synchronize()
+
         return o.view(-1, self.tp_q_head_num * self.head_dim)
 
     def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         self.store_kv_cache(k, v, input_metadata)
 
-        o = input_metadata.decode_wrapper.forward(
+        o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
             input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            logits_soft_cap=self.logit_cap,
         )
 
         return o.view(-1, self.tp_q_head_num * self.head_dim)
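
In the new prefill path, attention over the freshly extended tokens (ragged) and over the cached prefix (paged) is computed separately and the two partial results are combined with flashinfer's merge_state. A pure-PyTorch sketch of the log-sum-exp merge this performs, assuming s1 and s2 are per-query, per-head log-sum-exp values of the attention scores for two disjoint KV partitions (the real kernel is fused and its scaling conventions may differ):

    import torch

    def merge_attention_states(o1, s1, o2, s2):
        # o1, o2: [num_tokens, num_heads, head_dim] partial attention outputs
        # s1, s2: [num_tokens, num_heads] log-sum-exp of the attention scores per partition
        # Weight each partial output by its share of the combined softmax mass.
        w = torch.softmax(torch.stack([s1, s2]), dim=0)          # [2, tokens, heads]
        o = w[0].unsqueeze(-1) * o1 + w[1].unsqueeze(-1) * o2    # merged attention output
        s = torch.logsumexp(torch.stack([s1, s2]), dim=0)        # merged log-sum-exp
        return o, s

Because each partial output is already normalized within its own partition, re-weighting by the partition's softmax mass reproduces attention over the union of both KV sets without recomputing any scores.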

sglang/srt/managers/controller/dp_worker.py

@@ -1,9 +1,10 @@
 """A data parallel worker thread."""
+
 import asyncio
 import logging
 import queue
 import threading
-from typing import List, Callable
+from typing import Callable, List
 
 import uvloop
 import zmq
@@ -69,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
 
         # async sleep for receiving the subsequent request and avoiding cache miss
         if len(out_pyobjs) != 0:
-            has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+            has_finished = any(
+                [obj.finished_reason is not None for obj in out_pyobjs]
+            )
             if has_finished:
                 await asyncio.sleep(self.request_dependency_delay)
             await asyncio.sleep(global_config.wait_for_new_request_delay)
@@ -107,4 +110,4 @@ def start_data_parallel_worker(
         step_func=model_tp_client.step,
     )
     worker_thread.start()
-    return worker_thread
+    return worker_thread

sglang/srt/managers/controller/infer_batch.py

@@ -1,4 +1,6 @@
 """Meta data for requests and batches"""
+
+import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import List
@@ -6,9 +8,13 @@ from typing import List
 import numpy as np
 import torch
 
+from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 
+INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
 
 class ForwardMode(IntEnum):
     PREFILL = auto()
@@ -63,12 +69,15 @@ class Req:
     def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
         self.origin_input_text = origin_input_text
+        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
         self.origin_input_ids = origin_input_ids
-        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
-        self.prev_output_str = ""
-        self.prev_output_ids = []
-        self.output_ids = []
-        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
+        self.output_ids = []  # Each decode stage's output ids
+        self.input_ids = None  # input_ids = origin_input_ids + output_ids
+
+        # For incremental decode
+        self.decoded_text = ""
+        self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
+        self.read_offset = None
 
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
@@ -108,20 +117,54 @@ class Req:
         self.last_update_decode_tokens = 0
 
         # Constrained decoding
-        self.regex_fsm = None
-        self.regex_fsm_state = 0
-        self.jump_forward_map = None
+        self.regex_fsm: RegexGuide = None
+        self.regex_fsm_state: int = 0
+        self.jump_forward_map: JumpForwardMap = None
 
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
 
-    def partial_decode(self, ids):
-        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
+    # Based on https://github.com/vllm-project/vllm/blob/7a64d24aad69e4d2548aa0bf528d9fe63428ab01/vllm/transformers_utils/detokenizer.py#L194-L313
+    def init_detokenize_incrementally(self):
+        first_iter = self.surr_offset is None or self.read_offset is None
+
+        if first_iter:
+            self.read_offset = len(self.origin_input_ids_unpadded)
+            self.surr_offset = max(
+                self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
+            )
+
+        all_ids = self.origin_input_ids_unpadded + self.output_ids
+        surr_ids = all_ids[self.surr_offset : self.read_offset]
+        read_ids = all_ids[self.surr_offset :]
+
+        return surr_ids, read_ids, len(all_ids)
+
+    def detokenize_incrementally(self, inplace: bool = True):
+        surr_ids, read_ids, num_all_tokens = self.init_detokenize_incrementally()
+
+        surr_text = self.tokenizer.decode(
+            surr_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
+        )
+        new_text = self.tokenizer.decode(
+            read_ids,
+            skip_special_tokens=self.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=self.sampling_params.spaces_between_special_tokens,
         )
-        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
+
+        if len(new_text) > len(surr_text) and not new_text.endswith("�"):
+            new_text = new_text[len(surr_text) :]
+            if inplace:
+                self.decoded_text += new_text
+                self.surr_offset = self.read_offset
+                self.read_offset = num_all_tokens
+
+            return True, new_text
+
+        return False, ""
 
     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
@@ -130,18 +173,17 @@ class Req:
         if self.finished():
             return
 
-        if (
-            len(self.prev_output_ids) + len(self.output_ids)
-            >= self.sampling_params.max_new_tokens
-        ):
-            self.finished_reason = FINISH_LENGTH(len(self.prev_output_ids) + len(self.output_ids))
+        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+            self.finished_reason = FINISH_LENGTH(len(self.output_ids))
             return
 
         if (
             self.output_ids[-1] == self.tokenizer.eos_token_id
             and not self.sampling_params.ignore_eos
         ):
-            self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.tokenizer.eos_token_id)
+            self.finished_reason = FINISH_MATCHED_TOKEN(
+                matched=self.tokenizer.eos_token_id
+            )
             return
 
         if len(self.sampling_params.stop_strs) > 0:
@@ -150,61 +192,59 @@ class Req:
             )
 
             for stop_str in self.sampling_params.stop_strs:
-                # FIXME: (minor) try incremental match in prev_output_str
-                if stop_str in tail_str or stop_str in self.prev_output_str:
+                if stop_str in tail_str or stop_str in self.decoded_text:
                     self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                     return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        cur_output_str = self.partial_decode(self.output_ids)
-
-        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
         if self.origin_input_text is None:
             # Recovering text can only use unpadded ids
             self.origin_input_text = self.tokenizer.decode(
                 self.origin_input_ids_unpadded
             )
 
-        all_text = (
-            self.origin_input_text
-            + self.prev_output_str
-            + cur_output_str
-            + jump_forward_str
-        )
+        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
         prompt_tokens = len(self.origin_input_ids_unpadded)
-        self.origin_input_ids = all_ids[:prompt_tokens]
-        self.origin_input_ids_unpadded = self.origin_input_ids
-        # NOTE: the output ids may not strictly correspond to the output text
-        old_prev_output_ids = self.prev_output_ids
-        self.prev_output_ids = all_ids[prompt_tokens:]
-        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
-        self.output_ids = []
+
+        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
+            # TODO(lsyin): fix token fusion
+            warnings.warn(
+                "Token fusion between input and output, try to avoid this by removing the space at the end of the input."
+            )
+            return False
+
+        old_output_ids = self.output_ids
+        self.output_ids = all_ids[prompt_tokens:]
+        self.decoded_text = self.decoded_text + jump_forward_str
+        self.surr_offset = prompt_tokens
+        self.read_offset = len(all_ids)
+
+        # NOTE: A trick to reduce the surrouding tokens decoding overhead
+        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
+            surr_text_ = self.tokenizer.decode(
+                all_ids[self.read_offset - i : self.read_offset]
+            )
+            if not surr_text_.endswith("�"):
+                self.surr_offset = self.read_offset - i
+                break
 
         self.regex_fsm_state = next_state
 
         if self.return_logprob:
             # For fast-forward part's logprobs
             k = 0
-            for i, old_id in enumerate(old_prev_output_ids):
-                if old_id == self.prev_output_ids[i]:
+            for i, old_id in enumerate(old_output_ids):
+                if old_id == self.output_ids[i]:
                     k = k + 1
                 else:
                     break
             self.decode_token_logprobs = self.decode_token_logprobs[:k]
             self.decode_top_logprobs = self.decode_top_logprobs[:k]
             self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.prev_output_ids) - k
-
-        # print("=" * 100)
-        # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+            self.last_update_decode_tokens = len(self.output_ids) - k
 
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-        # print("*" * 100)
+        return True
 
     def __repr__(self):
         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
@@ -263,6 +303,10 @@ class Batch:
     def is_empty(self):
         return len(self.reqs) == 0
 
+    # whether batch has at least 1 streaming request
+    def has_stream(self) -> bool:
+        return any(r.stream for r in self.reqs)
+
     def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
         bs = len(self.reqs)
@@ -380,7 +424,10 @@ class Batch:
         sorted_indices = [i for i in range(len(self.reqs))]
         # TODO(lsyin): improve the priority of retraction
         sorted_indices.sort(
-            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
+            key=lambda i: (
+                len(self.reqs[i].output_ids),
+                -len(self.reqs[i].origin_input_ids),
+            ),
             reverse=True,
         )
 
@@ -402,14 +449,9 @@ class Batch:
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
 
-            cur_output_str = req.partial_decode(req.output_ids)
-            req.prev_output_str = req.prev_output_str + cur_output_str
-            req.prev_output_ids.extend(req.output_ids)
-
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
-            req.output_ids = []
 
             # For incremental logprobs
             req.last_update_decode_tokens = 0
@@ -427,18 +469,54 @@ class Batch:
 
         for i, req in enumerate(self.reqs):
             if req.jump_forward_map is not None:
-                res = req.jump_forward_map.jump_forward(req.regex_fsm_state)
-                if res is not None:
-                    jump_forward_str, next_state = res
-                    if len(jump_forward_str) <= 1:
+                jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
+                    req.regex_fsm_state
+                )
+                if jump_forward_bytes is not None and len(jump_forward_bytes) > 1:
+                    suffix_bytes = []
+                    continuation_range = range(0x80, 0xC0)
+                    cur_state = req.regex_fsm_state
+                    while (
+                        len(jump_forward_bytes)
+                        and jump_forward_bytes[0][0] in continuation_range
+                    ):
+                        # continuation bytes
+                        byte_edge = jump_forward_bytes.pop(0)
+                        suffix_bytes.append(byte_edge[0])
+                        cur_state = byte_edge[1]
+
+                    suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+                    suffix_ids = req.tokenizer.convert_tokens_to_ids(suffix_tokens)
+
+                    # Current ids, for cache and revert
+                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
+                    cur_output_ids = req.output_ids
+
+                    req.output_ids.extend(suffix_ids)
+                    decode_res, new_text = req.detokenize_incrementally(inplace=False)
+                    if not decode_res:
+                        req.output_ids = cur_output_ids
                         continue
 
-                    if req_pool_indices_cpu is None:
-                        req_pool_indices_cpu = self.req_pool_indices.tolist()
+                    (
+                        jump_forward_str,
+                        next_state,
+                    ) = req.jump_forward_map.jump_forward_symbol(cur_state)
+
+                    # Make the incrementally decoded text part of jump_forward_str
+                    # so that the UTF-8 will not corrupt
+                    jump_forward_str = new_text + jump_forward_str
+                    if not req.jump_forward_and_retokenize(
+                        jump_forward_str, next_state
+                    ):
+                        req.output_ids = cur_output_ids
+                        continue
 
                     # insert the old request into tree_cache
+                    if req_pool_indices_cpu is None:
+                        req_pool_indices_cpu = self.req_pool_indices.tolist()
                     self.tree_cache.cache_req(
-                        token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                        token_ids=cur_all_ids,
                         last_uncached_pos=len(req.prefix_indices),
                         req_pool_idx=req_pool_indices_cpu[i],
                     )
@@ -446,9 +524,6 @@ class Batch:
                     # unlock the last node
                     self.tree_cache.dec_lock_ref(req.last_node)
 
-                    # jump-forward
-                    req.jump_forward_and_retokenize(jump_forward_str, next_state)
-
                    # re-applying image padding
                     if req.pixel_values is not None:
                         (
@@ -582,7 +657,7 @@ class Batch:
             if req.regex_fsm is not None:
                 allowed_mask.zero_()
                 allowed_mask[
-                    req.regex_fsm.allowed_token_ids(req.regex_fsm_state)
+                    req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
                 ] = 1
                 logits[i].masked_fill_(~allowed_mask, float("-inf"))
 
@@ -601,7 +676,7 @@ class Batch:
         batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
         for i, req in enumerate(self.reqs):
             if req.regex_fsm is not None:
-                req.regex_fsm_state = req.regex_fsm.next_state(
+                req.regex_fsm_state = req.regex_fsm.get_next_state(
                     req.regex_fsm_state, batch_next_token_ids_cpu[i]
                 )
 
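
The Req changes above replace partial_decode with surr_offset/read_offset bookkeeping so that streaming output never emits a partial UTF-8 sequence. A standalone sketch of the same idea using a plain Hugging Face tokenizer (the function name and example driver are illustrative; the real method also threads sampling_params through decode):

    INIT_OFFSET = 5  # mirrors INIT_INCREMENTAL_DETOKENIZATION_OFFSET

    def detokenize_incrementally(tokenizer, all_ids, surr_offset, read_offset):
        # Decode a small "surrounding" window plus the unread tail, and emit only
        # the text that extends past the window, following the same logic as
        # Req.detokenize_incrementally.
        surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
        new_full_text = tokenizer.decode(all_ids[surr_offset:])
        if len(new_full_text) > len(surr_text) and not new_full_text.endswith("�"):
            # Safe to emit: slide the window forward.
            return new_full_text[len(surr_text):], read_offset, len(all_ids)
        # Trailing ids form an incomplete UTF-8 sequence: emit nothing, keep offsets.
        return "", surr_offset, read_offset

    # Example driver (the checkpoint name is illustrative):
    # from transformers import AutoTokenizer
    # tok = AutoTokenizer.from_pretrained("gpt2")
    # ids = tok.encode("Hello 世界, streaming output")
    # read = 2                                   # pretend the first 2 ids are the prompt
    # surr = max(read - INIT_OFFSET, 0)
    # text = ""
    # for n in range(3, len(ids) + 1):           # feed one "generated" token at a time
    #     piece, surr, read = detokenize_incrementally(tok, ids[:n], surr, read)
    #     text += piece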

sglang/srt/managers/controller/manager_multi.py

@@ -13,15 +13,15 @@ import zmq
 import zmq.asyncio
 
 from sglang.global_config import global_config
+from sglang.srt.managers.controller.dp_worker import (
+    DataParallelWorkerThread,
+    start_data_parallel_worker,
+)
 from sglang.srt.managers.io_struct import (
     AbortReq,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.dp_worker import (
-    DataParallelWorkerThread,
-    start_data_parallel_worker,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.utils import get_exception_traceback
 
@@ -136,7 +136,7 @@ class Controller:
             self.recv_reqs = []
             if next_step_input:
                 await self.dispatching(next_step_input)
-            #else:
+            # else:
             #     logger.error("There is no live worker.")
 
             await asyncio.sleep(global_config.wait_for_new_request_delay)

sglang/srt/managers/controller/manager_single.py

@@ -1,7 +1,8 @@
 """A controller that manages a group of tensor parallel workers."""
+
 import asyncio
 import logging
-import time
+from concurrent.futures import ThreadPoolExecutor
 
 import uvloop
 import zmq
@@ -49,7 +50,9 @@ class ControllerSingle:
         # async sleep for receiving the subsequent request and avoiding cache miss
         slept = False
         if len(out_pyobjs) != 0:
-            has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+            has_finished = any(
+                [obj.finished_reason is not None for obj in out_pyobjs]
+            )
             if has_finished:
                 if self.request_dependency_delay > 0:
                     slept = True
@@ -73,8 +76,9 @@ def start_controller_process(
     )
 
     try:
+        tp_size_local = server_args.tp_size // server_args.nnodes
         model_client = ModelTpClient(
-            list(range(server_args.tp_size)),
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
             server_args,
             port_args.model_port_args[0],
             model_overide_args,
@@ -87,6 +91,7 @@ def start_controller_process(
     pipe_writer.send("init ok")
 
     loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
     try:
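
For the multi-node change in start_controller_process, the new list comprehension repeats the per-node local ranks once per node instead of enumerating global ranks. A quick worked example with illustrative values (not from the diff):

    tp_size, nnodes = 8, 2                      # illustrative values
    tp_size_local = tp_size // nnodes           # 4 tensor-parallel ranks per node
    gpu_ids = [i for _ in range(nnodes) for i in range(tp_size_local)]
    assert gpu_ids == [0, 1, 2, 3, 0, 1, 2, 3]  # local device index for each rank

The previous code passed list(range(server_args.tp_size)), i.e. [0, ..., 7], which only makes sense when every rank lives on the same node.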