sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0

sglang/srt/managers/controller_multi.py

@@ -71,12 +71,10 @@ class ControllerMulti:
  self,
  server_args: ServerArgs,
  port_args: PortArgs,
- model_override_args,
  ):
  # Parse args
  self.server_args = server_args
  self.port_args = port_args
- self.model_override_args = model_override_args
  self.load_balance_method = LoadBalanceMethod.from_str(
  server_args.load_balance_method
  )
@@ -114,7 +112,6 @@ class ControllerMulti:
  self.server_args,
  self.port_args,
  pipe_controller_writer,
- self.model_override_args,
  True,
  gpu_ids,
  dp_worker_id,
@@ -189,14 +186,13 @@ def start_controller_process(
  server_args: ServerArgs,
  port_args: PortArgs,
  pipe_writer,
- model_override_args: dict,
  ):
  """Start a controller process."""

  configure_logger(server_args)

  try:
- controller = ControllerMulti(server_args, port_args, model_override_args)
+ controller = ControllerMulti(server_args, port_args)
  except Exception:
  pipe_writer.send(get_exception_traceback())
  raise

sglang/srt/managers/controller_single.py

@@ -40,7 +40,6 @@ class ControllerSingle:
  self,
  server_args: ServerArgs,
  port_args: PortArgs,
- model_override_args: dict,
  gpu_ids: List[int],
  is_data_parallel_worker: bool,
  dp_worker_id: int,
@@ -76,7 +75,6 @@ class ControllerSingle:
  tp_rank_range,
  server_args,
  port_args.nccl_ports[dp_worker_id],
- model_override_args,
  )

  # Launch tp rank 0
@@ -85,7 +83,6 @@ class ControllerSingle:
  0,
  server_args,
  port_args.nccl_ports[dp_worker_id],
- model_override_args,
  )
  self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group

@@ -126,7 +123,6 @@ def start_controller_process(
  server_args: ServerArgs,
  port_args: PortArgs,
  pipe_writer: multiprocessing.connection.Connection,
- model_override_args: dict,
  is_data_parallel_worker: bool = False,
  gpu_ids: List[int] = None,
  dp_worker_id: int = None,
@@ -149,7 +145,6 @@ def start_controller_process(
  controller = ControllerSingle(
  server_args,
  port_args,
- model_override_args,
  gpu_ids,
  is_data_parallel_worker,
  dp_worker_id,

sglang/srt/managers/io_struct.py

@@ -20,7 +20,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).

  import copy
  import uuid
- from dataclasses import dataclass, field
+ from dataclasses import dataclass
  from typing import Dict, List, Optional, Union

  from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -43,6 +43,7 @@ class GenerateReqInput:
  # Whether to return logprobs.
  return_logprob: Optional[Union[List[bool], bool]] = None
  # If return logprobs, the start location in the prompt for returning logprobs.
+ # By default, this value is "-1", which means it will only return logprobs for output tokens.
  logprob_start_len: Optional[Union[List[int], int]] = None
  # If return logprobs, the number of top logprobs to return at each position.
  top_logprobs_num: Optional[Union[List[int], int]] = None
@@ -50,6 +51,13 @@ class GenerateReqInput:
  return_text_in_logprobs: bool = False
  # Whether to stream output.
  stream: bool = False
+ # The modalities of the image data [image, multi-images, video]
+ modalities: Optional[List[str]] = None
+
+ is_single: bool = True
+
+ # LoRA related
+ lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

  def post_init(self):
  if (self.text is None and self.input_ids is None) or (
@@ -177,6 +185,11 @@ class TokenizedGenerateReqInput:
  top_logprobs_num: int
  # Whether to stream output
  stream: bool
+ # Modalities of the input images
+ modalites: Optional[List[str]] = None
+
+ # LoRA related
+ lora_path: Optional[str] = None # None means just use the base model


  @dataclass
@@ -190,6 +203,8 @@ class EmbeddingReqInput:
  # Dummy sampling params for compatibility
  sampling_params: Union[List[Dict], Dict] = None

+ is_single: bool = True
+
  def post_init(self):
  if (self.text is None and self.input_ids is None) or (
  self.text is not None and self.input_ids is not None
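
A note on the new GenerateReqInput fields: a request can now carry a modalities hint for vision inputs and a per-request lora_path. Below is a minimal sketch of constructing such a request, assuming the dataclass defaults visible above; the prompt text, adapter name, and modality value are purely illustrative.

```python
# Illustrative only: exercises the fields added in this hunk.
from sglang.srt.managers.io_struct import GenerateReqInput

req = GenerateReqInput(
    text="Describe this image.",
    stream=True,
    modalities=["image"],        # per the comment above: image, multi-images, or video
    lora_path="my-lora-adapter", # hypothetical adapter name; None means use the base model
)
```

In the tokenized form (TokenizedGenerateReqInput), lora_path is a single optional string, with None meaning the base model, as noted in the hunk above.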

sglang/srt/managers/policy_scheduler.py

@@ -108,18 +108,25 @@ class PrefillAdder:
  def __init__(
  self,
  tree_cache: BasePrefixCache,
+ running_batch: ScheduleBatch,
+ new_token_ratio: float,
  rem_total_tokens: int,
  rem_input_tokens: int,
  rem_chunk_tokens: Optional[int],
  mixed_with_decode_tokens: int = 0,
  ):
  self.tree_cache = tree_cache
+ self.running_batch = running_batch
+ self.new_token_ratio = new_token_ratio
  self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+ self.rem_total_tokens_ = self.rem_total_tokens
+ self.total_tokens = rem_total_tokens
  self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
  self.rem_chunk_tokens = rem_chunk_tokens
  if self.rem_chunk_tokens is not None:
  self.rem_chunk_tokens -= mixed_with_decode_tokens

+ self.req_states = None
  self.can_run_list = []
  self.new_inflight_req = None
  self.log_hit_tokens = 0
@@ -136,16 +143,20 @@ class PrefillAdder:
  )
  )

- def remove_running_tokens(
- self, running_batch: ScheduleBatch, new_token_ratio: float
- ):
+ def remove_running_tokens(self, running_batch: ScheduleBatch):
  self.rem_total_tokens -= sum(
  [
  min(
  (r.sampling_params.max_new_tokens - len(r.output_ids)),
  CLIP_MAX_NEW_TOKENS,
  )
- * new_token_ratio
+ * self.new_token_ratio
+ for r in running_batch.reqs
+ ]
+ )
+ self.rem_total_tokens_ -= sum(
+ [
+ r.sampling_params.max_new_tokens - len(r.output_ids)
  for r in running_batch.reqs
  ]
  )
@@ -154,6 +165,7 @@ class PrefillAdder:
  self, prefix_len: int, extend_input_len: int, max_new_tokens: int
  ):
  self.rem_total_tokens -= extend_input_len + max_new_tokens
+ self.rem_total_tokens_ -= extend_input_len + max_new_tokens
  self.rem_input_tokens -= extend_input_len
  if self.rem_chunk_tokens is not None:
  self.rem_chunk_tokens -= extend_input_len
@@ -161,7 +173,29 @@ class PrefillAdder:
  self.log_hit_tokens += prefix_len
  self.log_input_tokens += extend_input_len

+ def add_inflight_req_ignore_eos(self, req: Req):
+ truncated = req.extend_input_len > self.rem_chunk_tokens
+ req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+ req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
+ self.can_run_list.append(req)
+
+ self._prefill_one_req(
+ 0,
+ req.extend_input_len,
+ (
+ min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+ if not truncated
+ else 0
+ ),
+ )
+
+ # Return if chunked prefill not finished
+ return req if truncated else None
+
  def add_inflight_req(self, req: Req):
+ if req.sampling_params.ignore_eos:
+ return self.add_inflight_req_ignore_eos(req)
+
  truncated = req.extend_input_len > self.rem_chunk_tokens
  req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
  req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -190,7 +224,90 @@ class PrefillAdder:
  delta = self.tree_cache.dec_lock_ref(last_node)
  self.rem_total_tokens += delta

+ def add_one_req_ignore_eos(self, req: Req):
+ def get_req_state(r):
+ new_token_ratio = (
+ 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
+ )
+ tokens_left = r.sampling_params.max_new_tokens * new_token_ratio - len(
+ r.output_ids
+ )
+ tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
+
+ if tokens_left > 0:
+ return (tokens_left, tokens_occupied)
+
+ return None
+
+ # Quick Check
+ can_run = False
+ if (
+ req.extend_input_len + req.sampling_params.max_new_tokens
+ <= self.rem_total_tokens
+ ):
+ can_run = True
+
+ if not can_run:
+ if self.req_states is None:
+ self.req_states = []
+ if self.running_batch is not None:
+ for r in self.running_batch.reqs:
+ state = get_req_state(r)
+ if state is not None:
+ self.req_states.append(state)
+ for r in self.can_run_list:
+ state = get_req_state(r)
+ if state is not None:
+ self.req_states.append(state)
+ state = get_req_state(req)
+ if state is not None:
+ self.req_states.append(state)
+
+ self.req_states.sort(key=lambda x: x[0])
+ else:
+ state = get_req_state(req)
+ if state is not None:
+ for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+ if tokens_left >= state[0]:
+ self.req_states.insert(i, state)
+ break
+ else:
+ self.req_states.append(state)
+
+ tokens_freed = 0
+ for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+ decode_steps = (
+ self.req_states[i + 1][0]
+ if i + 1 < len(self.req_states)
+ else tokens_left
+ )
+ bs = len(self.req_states) - i
+ if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
+ return False
+ tokens_freed += tokens_occupied
+
+ if req.extend_input_len <= self.rem_chunk_tokens:
+ self.can_run_list.append(req)
+ self._prefill_one_req(
+ 0,
+ req.extend_input_len,
+ min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+ )
+ else:
+ # Chunked prefill
+ trunc_len = self.rem_chunk_tokens
+ req.extend_input_len = trunc_len
+ req.fill_ids = req.fill_ids[:trunc_len]
+ self.can_run_list.append(req)
+ self.new_inflight_req = req
+ self._prefill_one_req(0, trunc_len, 0)
+
+ return True
+
  def add_one_req(self, req: Req):
+ if req.sampling_params.ignore_eos and self.tree_cache.disable:
+ return self.add_one_req_ignore_eos(req)
+
  total_tokens = req.extend_input_len + min(
  req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
  )
@@ -233,4 +350,4 @@ class PrefillAdder:
  self.tree_cache.inc_lock_ref(req.last_node)
  self._prefill_one_req(prefix_len, trunc_len, 0)

- return True
+ return True and not self.no_remaining_tokens()
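
The heart of the new add_one_req_ignore_eos path is the budget check over req_states. A standalone restatement of that loop is sketched below so the accounting is easier to follow; the function name and wrapper are not part of the diff, and req_states is assumed to already be sorted by tokens_left ascending, as in the code above.

```python
from typing import List, Tuple

def fits_token_budget(req_states: List[Tuple[float, int]], total_tokens: int) -> bool:
    """Mirrors the feasibility loop in PrefillAdder.add_one_req_ignore_eos.

    Each entry is (tokens_left, tokens_occupied). Walking the sorted list
    simulates requests finishing in order of remaining budget: the bs requests
    still running need roughly decode_steps new tokens each before the next
    one finishes, offset by the tokens already freed by finished requests.
    """
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        decode_steps = (
            req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        )
        bs = len(req_states) - i
        if total_tokens + tokens_freed - decode_steps * bs <= 0:
            return False  # memory would run out before this point
        tokens_freed += tokens_occupied  # this request's tokens get released
    return True
```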

sglang/srt/managers/schedule_batch.py

@@ -19,7 +19,7 @@ limitations under the License.

  import logging
  from dataclasses import dataclass
- from typing import TYPE_CHECKING, List, Optional, Union
+ from typing import List, Optional, Tuple, Union

  import torch

@@ -29,20 +29,19 @@ from sglang.srt.constrained.jump_forward import JumpForwardMap
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-
- if TYPE_CHECKING:
- from sglang.srt.layers.sampler import SampleOutput
-
+ from sglang.srt.server_args import ServerArgs

  INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

  # Put some global args for easy access
  global_server_args_dict = {
- "disable_flashinfer": False,
- "disable_flashinfer_sampling": False,
- "triton_attention_reduce_in_fp32": False,
- "enable_mla": False,
+ "attention_backend": ServerArgs.attention_backend,
+ "sampling_backend": ServerArgs.sampling_backend,
+ "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
+ "enable_mla": ServerArgs.enable_mla,
+ "torchao_config": ServerArgs.torchao_config,
  }


@@ -53,8 +52,8 @@ class BaseFinishReason:
  def __init__(self, is_error: bool = False):
  self.is_error = is_error

- def __str__(self):
- raise NotImplementedError("Subclasses must implement this method")
+ def to_json(self):
+ raise NotImplementedError()


  class FINISH_MATCHED_TOKEN(BaseFinishReason):
@@ -62,40 +61,57 @@ class FINISH_MATCHED_TOKEN(BaseFinishReason):
  super().__init__()
  self.matched = matched

- def __str__(self) -> str:
- return f"FINISH_MATCHED_TOKEN: {self.matched}"
+ def to_json(self):
+ return {
+ "type": "stop", # to match OpenAI API's return value
+ "matched": self.matched,
+ }


- class FINISH_LENGTH(BaseFinishReason):
- def __init__(self, length: int):
+ class FINISH_MATCHED_STR(BaseFinishReason):
+ def __init__(self, matched: str):
  super().__init__()
- self.length = length
+ self.matched = matched

- def __str__(self) -> str:
- return f"FINISH_LENGTH: {self.length}"
+ def to_json(self):
+ return {
+ "type": "stop", # to match OpenAI API's return value
+ "matched": self.matched,
+ }


- class FINISH_MATCHED_STR(BaseFinishReason):
- def __init__(self, matched: str):
+ class FINISH_LENGTH(BaseFinishReason):
+ def __init__(self, length: int):
  super().__init__()
- self.matched = matched
+ self.length = length

- def __str__(self) -> str:
- return f"FINISH_MATCHED_STR: {self.matched}"
+ def to_json(self):
+ return {
+ "type": "length", # to match OpenAI API's return value
+ "length": self.length,
+ }


  class FINISH_ABORT(BaseFinishReason):
  def __init__(self):
  super().__init__(is_error=True)

- def __str__(self) -> str:
- return "FINISH_ABORT"
+ def to_json(self):
+ return {
+ "type": "abort",
+ }


  class Req:
  """Store all inforamtion of a request."""

- def __init__(self, rid, origin_input_text, origin_input_ids):
+ def __init__(
+ self,
+ rid: str,
+ origin_input_text: str,
+ origin_input_ids: Tuple[int],
+ lora_path: Optional[str] = None,
+ ):
  # Input and output info
  self.rid = rid
  self.origin_input_text = origin_input_text
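
A quick reference for the new serialization (an aside; the schedule_batch.py diff continues below): finish reasons now return small dicts whose "type" matches the OpenAI API's finish_reason values, instead of the old __str__ strings. The calls below are illustrative only, with made-up token id, stop string, and length.

```python
from sglang.srt.managers.schedule_batch import (
    FINISH_ABORT,
    FINISH_LENGTH,
    FINISH_MATCHED_STR,
    FINISH_MATCHED_TOKEN,
)

FINISH_MATCHED_TOKEN(matched=2).to_json()     # {"type": "stop", "matched": 2}
FINISH_MATCHED_STR(matched="</s>").to_json()  # {"type": "stop", "matched": "</s>"}
FINISH_LENGTH(length=128).to_json()           # {"type": "length", "length": 128}
FINISH_ABORT().to_json()                      # {"type": "abort"}
```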
@@ -103,10 +119,15 @@ class Req:
  self.origin_input_ids = origin_input_ids
  self.output_ids = [] # Each decode stage's output ids
  self.fill_ids = None # fill_ids = origin_input_ids + output_ids
+ self.lora_path = lora_path

  # Memory info
  self.req_pool_idx = None

+ # Check finish
+ self.tokenizer = None
+ self.finished_reason = None
+
  # For incremental decoding
  # ----- | --------- read_ids -------|
  # ----- | surr_ids |
@@ -125,38 +146,43 @@ class Req:
  # this does not include the jump forward tokens.
  self.completion_tokens_wo_jump_forward = 0

- # For vision input
+ # For vision inputs
  self.pixel_values = None
  self.image_sizes = None
  self.image_offsets = None
  self.pad_value = None
+ self.modalities = None

  # Prefix info
- self.extend_input_len = 0
  self.prefix_indices = []
+ self.extend_input_len = 0
  self.last_node = None

  # Sampling parameters
  self.sampling_params = None
  self.stream = False

- # Check finish
- self.tokenizer = None
- self.finished_reason = None
-
- # Logprobs
+ # Logprobs (arguments)
  self.return_logprob = False
- self.embedding = None
  self.logprob_start_len = 0
  self.top_logprobs_num = 0
+
+ # Logprobs (return value)
  self.normalized_prompt_logprob = None
  self.input_token_logprobs = None
  self.input_top_logprobs = None
  self.output_token_logprobs = []
  self.output_top_logprobs = []
+
+ # Logprobs (internal values)
  # The tokens is prefilled but need to be considered as decode tokens
  # and should be updated for the decode logprobs
  self.last_update_decode_tokens = 0
+ # The relative logprob_start_len in an extend batch
+ self.extend_logprob_start_len = 0
+
+ # Embedding
+ self.embedding = None

  # Constrained decoding
  self.regex_fsm: RegexGuide = None
@@ -333,6 +359,9 @@ class ScheduleBatch:
  token_to_kv_pool: BaseTokenToKVPool
  tree_cache: BasePrefixCache

+ forward_mode: ForwardMode = None
+ sampling_info: SamplingBatchInfo = None
+
  # Batched arguments to model runner
  input_ids: torch.Tensor = None
  req_pool_indices: torch.Tensor = None
@@ -343,14 +372,19 @@ class ScheduleBatch:

  # For mixed chunekd prefill
  prefix_lens_cpu: List[int] = None
+ running_bs: int = None

  # For processing logprobs
  return_logprob: bool = False
  top_logprobs_nums: List[int] = None

+ # Stream
+ has_stream: bool = False
+
  @classmethod
  def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
  return_logprob = any(req.return_logprob for req in reqs)
+ has_stream = any(req.stream for req in reqs)

  return cls(
  reqs=reqs,
@@ -358,18 +392,15 @@ class ScheduleBatch:
  req_to_token_pool=req_to_token_pool,
  token_to_kv_pool=token_to_kv_pool,
  tree_cache=tree_cache,
  return_logprob=return_logprob,
+ has_stream=has_stream,
  )

  def batch_size(self):
- return len(self.reqs) if self.reqs is not None else 0
+ return len(self.reqs)

  def is_empty(self):
  return len(self.reqs) == 0

- def has_stream(self) -> bool:
- # Return whether batch has at least 1 streaming request
- return any(r.stream for r in self.reqs)
-
  def alloc_req_slots(self, num_reqs):
  req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
  if req_pool_indices is None:
@@ -396,6 +427,8 @@ class ScheduleBatch:
  return out_cache_loc

  def prepare_for_extend(self, vocab_size: int):
+ self.forward_mode = ForwardMode.EXTEND
+
  bs = self.batch_size()
  reqs = self.reqs
  input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
@@ -410,8 +443,8 @@ class ScheduleBatch:
  for i, req in enumerate(reqs):
  req.req_pool_idx = req_pool_indices_cpu[i]
  pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
- ext_len = seq_len - pre_len
  seq_lens.append(seq_len)
+ assert seq_len - pre_len == req.extend_input_len

  if pre_len > 0:
  self.req_to_token_pool.req_to_token[req.req_pool_idx][
@@ -419,9 +452,19 @@ class ScheduleBatch:
  ] = req.prefix_indices

  self.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
- out_cache_loc[pt : pt + ext_len]
+ out_cache_loc[pt : pt + req.extend_input_len]
  )
- pt += ext_len
+
+ # Compute the relative logprob_start_len in an extend batch
+ if req.logprob_start_len >= pre_len:
+ extend_logprob_start_len = min(
+ req.logprob_start_len - pre_len, req.extend_input_len - 1
+ )
+ else:
+ extend_logprob_start_len = req.extend_input_len - 1
+
+ req.extend_logprob_start_len = extend_logprob_start_len
+ pt += req.extend_input_len

  # Set fields
  with torch.device("cuda"):
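
The relative logprob start computed in the loop above can be summarized as a pure function; this restatement is added for clarity only and simply mirrors the branch introduced in prepare_for_extend (the diff continues below).

```python
def relative_logprob_start(logprob_start_len: int, pre_len: int, extend_input_len: int) -> int:
    """Clamp the requested logprob start into the extend (non-cached) region.

    If the requested start lies at or beyond the cached prefix, it becomes
    relative to the prefix and is clamped to the extend window; if it falls
    inside the already-cached prefix, it is set to the last extend position.
    """
    if logprob_start_len >= pre_len:
        return min(logprob_start_len - pre_len, extend_input_len - 1)
    return extend_input_len - 1
```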
@@ -434,18 +477,13 @@ class ScheduleBatch:
  self.out_cache_loc = out_cache_loc
  self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
  self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
-
+ self.extend_lens_cpu = [r.extend_input_len for r in reqs]
+ self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
  self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)

  def mix_with_running(self, running_batch: "ScheduleBatch"):
- # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
- prefix_lens_cpu = [len(r.prefix_indices) for r in self.reqs]
- prefix_lens_cpu.extend(
- [
- len(r.origin_input_ids) + len(r.output_ids) - 1
- for r in running_batch.reqs
- ]
- )
+ self.forward_mode = ForwardMode.MIXED
+ running_bs = running_batch.batch_size()

  for req in running_batch.reqs:
  req.fill_ids = req.origin_input_ids + req.output_ids
@@ -453,12 +491,22 @@ class ScheduleBatch:

  input_ids = torch.cat([self.input_ids, running_batch.input_ids])
  out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
- extend_num_tokens = self.extend_num_tokens + running_batch.batch_size()
+ extend_num_tokens = self.extend_num_tokens + running_bs
+
  self.merge(running_batch)
  self.input_ids = input_ids
  self.out_cache_loc = out_cache_loc
  self.extend_num_tokens = extend_num_tokens
- self.prefix_lens_cpu = prefix_lens_cpu
+
+ # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
+ self.prefix_lens_cpu.extend(
+ [
+ len(r.origin_input_ids) + len(r.output_ids) - 1
+ for r in running_batch.reqs
+ ]
+ )
+ self.extend_lens_cpu.extend([1] * running_bs)
+ self.extend_logprob_start_lens_cpu.extend([0] * running_bs)

  def check_decode_mem(self):
  bs = self.batch_size()
@@ -625,6 +673,8 @@ class ScheduleBatch:
  return jump_forward_reqs

  def prepare_for_decode(self, input_ids=None):
+ self.forward_mode = ForwardMode.DECODE
+
  if input_ids is None:
  input_ids = [
  r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
@@ -644,8 +694,6 @@ class ScheduleBatch:
  self.req_pool_indices, self.seq_lens - 1
  ] = self.out_cache_loc

- self.sampling_info.update_regex_vocab_mask(self)
-
  def filter_batch(self, unfinished_indices: List[int]):
  if unfinished_indices is None or len(unfinished_indices) == 0:
  # Filter out all requests
@@ -665,6 +713,7 @@ class ScheduleBatch:
  self.out_cache_loc = None
  self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
  self.return_logprob = any(req.return_logprob for req in self.reqs)
+ self.has_stream = any(req.stream for req in self.reqs)

  self.sampling_info.filter(unfinished_indices, new_indices)

@@ -675,7 +724,6 @@ class ScheduleBatch:
  self.sampling_info.merge(other.sampling_info)

  self.reqs.extend(other.reqs)
-
  self.req_pool_indices = torch.concat(
  [self.req_pool_indices, other.req_pool_indices]
  )
@@ -686,18 +734,4 @@ class ScheduleBatch:
  self.out_cache_loc = None
  self.top_logprobs_nums.extend(other.top_logprobs_nums)
  self.return_logprob = any(req.return_logprob for req in self.reqs)
-
- def check_sample_results(self, sample_output: SampleOutput):
- if not torch.all(sample_output.success):
- probs = sample_output.probs
- batch_next_token_ids = sample_output.batch_next_token_ids
- logging.warning("Sampling failed, fallback to top_k=1 strategy")
- probs = probs.masked_fill(torch.isnan(probs), 0.0)
- argmax_ids = torch.argmax(probs, dim=-1)
- batch_next_token_ids = torch.where(
- sample_output.success, batch_next_token_ids, argmax_ids
- )
- sample_output.probs = probs
- sample_output.batch_next_token_ids = batch_next_token_ids
-
- return sample_output.batch_next_token_ids
+ self.has_stream = any(req.stream for req in self.reqs)