sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (87)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +46 -25
  4. sglang/bench_serving.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +14 -1
  6. sglang/lang/interpreter.py +16 -6
  7. sglang/lang/ir.py +20 -4
  8. sglang/srt/configs/model_config.py +11 -9
  9. sglang/srt/constrained/fsm_cache.py +9 -1
  10. sglang/srt/constrained/jump_forward.py +15 -2
  11. sglang/srt/layers/activation.py +4 -4
  12. sglang/srt/layers/attention/__init__.py +49 -0
  13. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  14. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  15. sglang/srt/layers/attention/triton_backend.py +161 -0
  16. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  17. sglang/srt/layers/layernorm.py +4 -4
  18. sglang/srt/layers/logits_processor.py +19 -15
  19. sglang/srt/layers/pooler.py +3 -3
  20. sglang/srt/layers/quantization/__init__.py +0 -2
  21. sglang/srt/layers/radix_attention.py +6 -4
  22. sglang/srt/layers/sampler.py +6 -4
  23. sglang/srt/layers/torchao_utils.py +18 -0
  24. sglang/srt/lora/lora.py +20 -21
  25. sglang/srt/lora/lora_manager.py +97 -25
  26. sglang/srt/managers/detokenizer_manager.py +31 -18
  27. sglang/srt/managers/image_processor.py +187 -0
  28. sglang/srt/managers/io_struct.py +99 -75
  29. sglang/srt/managers/schedule_batch.py +184 -63
  30. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  31. sglang/srt/managers/scheduler.py +1021 -0
  32. sglang/srt/managers/tokenizer_manager.py +120 -248
  33. sglang/srt/managers/tp_worker.py +28 -925
  34. sglang/srt/mem_cache/memory_pool.py +34 -52
  35. sglang/srt/model_executor/cuda_graph_runner.py +15 -19
  36. sglang/srt/model_executor/forward_batch_info.py +94 -95
  37. sglang/srt/model_executor/model_runner.py +76 -75
  38. sglang/srt/models/baichuan.py +10 -10
  39. sglang/srt/models/chatglm.py +12 -12
  40. sglang/srt/models/commandr.py +10 -10
  41. sglang/srt/models/dbrx.py +12 -12
  42. sglang/srt/models/deepseek.py +10 -10
  43. sglang/srt/models/deepseek_v2.py +14 -15
  44. sglang/srt/models/exaone.py +10 -10
  45. sglang/srt/models/gemma.py +10 -10
  46. sglang/srt/models/gemma2.py +11 -11
  47. sglang/srt/models/gpt_bigcode.py +10 -10
  48. sglang/srt/models/grok.py +10 -10
  49. sglang/srt/models/internlm2.py +10 -10
  50. sglang/srt/models/llama.py +14 -10
  51. sglang/srt/models/llama_classification.py +5 -5
  52. sglang/srt/models/llama_embedding.py +4 -4
  53. sglang/srt/models/llama_reward.py +142 -0
  54. sglang/srt/models/llava.py +39 -33
  55. sglang/srt/models/llavavid.py +31 -28
  56. sglang/srt/models/minicpm.py +10 -10
  57. sglang/srt/models/minicpm3.py +14 -15
  58. sglang/srt/models/mixtral.py +10 -10
  59. sglang/srt/models/mixtral_quant.py +10 -10
  60. sglang/srt/models/olmoe.py +10 -10
  61. sglang/srt/models/qwen.py +10 -10
  62. sglang/srt/models/qwen2.py +11 -11
  63. sglang/srt/models/qwen2_moe.py +10 -10
  64. sglang/srt/models/stablelm.py +10 -10
  65. sglang/srt/models/torch_native_llama.py +506 -0
  66. sglang/srt/models/xverse.py +10 -10
  67. sglang/srt/models/xverse_moe.py +10 -10
  68. sglang/srt/sampling/sampling_batch_info.py +36 -27
  69. sglang/srt/sampling/sampling_params.py +3 -1
  70. sglang/srt/server.py +170 -119
  71. sglang/srt/server_args.py +54 -27
  72. sglang/srt/utils.py +101 -128
  73. sglang/test/runners.py +71 -26
  74. sglang/test/test_programs.py +38 -5
  75. sglang/test/test_utils.py +18 -9
  76. sglang/version.py +1 -1
  77. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
  78. sglang-0.3.3.dist-info/RECORD +139 -0
  79. sglang/srt/layers/attention_backend.py +0 -474
  80. sglang/srt/managers/controller_multi.py +0 -207
  81. sglang/srt/managers/controller_single.py +0 -164
  82. sglang-0.3.2.dist-info/RECORD +0 -135
  83. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  84. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  85. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  86. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  87. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py

@@ -18,7 +18,6 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
-import copy
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
@@ -37,7 +36,7 @@ class GenerateReqInput:
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
@@ -53,9 +52,6 @@ class GenerateReqInput:
     stream: bool = False
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-
-    is_single: bool = True
-
     # LoRA related
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
@@ -65,19 +61,41 @@
         ):
             raise ValueError("Either text or input_ids should be provided.")
 
-        if (
-            isinstance(self.sampling_params, dict)
-            and self.sampling_params.get("n", 1) != 1
-        ):
-            is_single = False
+        self.is_single = False
+        if self.text is not None:
+            if isinstance(self.text, str):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.text)
         else:
-            if self.text is not None:
-                is_single = isinstance(self.text, str)
+            if isinstance(self.input_ids[0], int):
+                self.is_single = True
+                self.batch_size = 1
             else:
-                is_single = isinstance(self.input_ids[0], int)
-        self.is_single = is_single
-
-        if is_single:
+                self.batch_size = len(self.input_ids)
+
+        if self.sampling_params is None:
+            self.parallel_sample_num = 1
+        elif isinstance(self.sampling_params, dict):
+            self.parallel_sample_num = self.sampling_params.get("n", 1)
+        else:  # isinstance(self.sampling_params, list):
+            self.parallel_sample_num = self.sampling_params[0].get("n", 1)
+            for sp in self.sampling_params:
+                # TODO cope with the case that the parallel_sample_num is different for different samples
+                assert self.parallel_sample_num == sp.get(
+                    "n", 1
+                ), "The parallel_sample_num should be the same for all samples in sample params."
+
+        if self.parallel_sample_num > 1:
+            if self.is_single:
+                self.is_single = False
+                if self.text is not None:
+                    self.text = [self.text]
+                if self.input_ids is not None:
+                    self.input_ids = [self.input_ids]
+
+        if self.is_single:
             if self.sampling_params is None:
                 self.sampling_params = {}
             if self.rid is None:
@@ -89,79 +107,54 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            parallel_sample_num_list = []
-            if isinstance(self.sampling_params, dict):
-                parallel_sample_num = self.sampling_params.get("n", 1)
-            elif isinstance(self.sampling_params, list):
-                for sp in self.sampling_params:
-                    parallel_sample_num = sp.get("n", 1)
-                    parallel_sample_num_list.append(parallel_sample_num)
-                parallel_sample_num = max(parallel_sample_num_list)
-                all_equal = all(
-                    element == parallel_sample_num
-                    for element in parallel_sample_num_list
-                )
-                if parallel_sample_num > 1 and (not all_equal):
-                    # TODO cope with the case that the parallel_sample_num is different for different samples
-                    raise ValueError(
-                        "The parallel_sample_num should be the same for all samples in sample params."
-                    )
+            if self.parallel_sample_num == 1:
+                num = self.batch_size
             else:
-                parallel_sample_num = 1
-            self.parallel_sample_num = parallel_sample_num
-
-            if parallel_sample_num != 1:
-                # parallel sampling +1 represents the original prefill stage
-                num = parallel_sample_num + 1
-                if isinstance(self.text, list):
-                    # suppot batch operation
-                    self.batch_size = len(self.text)
-                    num = num * len(self.text)
-                elif isinstance(self.input_ids, list) and isinstance(
-                    self.input_ids[0], list
-                ):
-                    self.batch_size = len(self.input_ids)
-                    num = num * len(self.input_ids)
-                else:
-                    self.batch_size = 1
-            else:
-                # support select operation
-                num = len(self.text) if self.text is not None else len(self.input_ids)
-                self.batch_size = num
+                # FIXME support cascade inference
+                # first bs samples are used for caching the prefix for parallel sampling
+                num = self.batch_size + self.parallel_sample_num * self.batch_size
 
             if self.image_data is None:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
             elif isinstance(self.image_data, list):
-                # multi-image with n > 1
+                # FIXME incorrect order for duplication
                 self.image_data = self.image_data * num
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
             elif not isinstance(self.sampling_params, list):
                 self.sampling_params = [self.sampling_params] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-                if not isinstance(self.rid, list):
-                    raise ValueError("The rid should be a list.")
+                assert isinstance(self.rid, list), "The rid should be a list."
+                assert self.parallel_sample_num == 1
 
             if self.return_logprob is None:
                 self.return_logprob = [False] * num
             elif not isinstance(self.return_logprob, list):
                 self.return_logprob = [self.return_logprob] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.logprob_start_len is None:
                 self.logprob_start_len = [-1] * num
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num
+            else:
+                assert self.parallel_sample_num == 1
 
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = [0] * num
             elif not isinstance(self.top_logprobs_num, list):
                 self.top_logprobs_num = [self.top_logprobs_num] * num
+            else:
+                assert self.parallel_sample_num == 1
 
 
 @dataclass
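
For reference, a minimal standalone sketch of the request-expansion arithmetic introduced in the new `post_init` above (the function name here is illustrative, not part of the sglang API): with a batch of `bs` prompts and `n > 1` parallel samples, per-request fields such as `rid`, `image_data`, and `sampling_params` are replicated to `bs + n * bs` entries, where the first `bs` entries only cache the shared prefix.

```python
def expanded_request_count(batch_size: int, parallel_sample_num: int) -> int:
    """Mirror of the arithmetic in GenerateReqInput.post_init (illustrative only)."""
    if parallel_sample_num == 1:
        return batch_size
    # The first `batch_size` requests cache the shared prefix; the remaining
    # `parallel_sample_num * batch_size` requests produce the parallel samples.
    return batch_size + parallel_sample_num * batch_size


# Example: 2 prompts with sampling_params={"n": 3} expand into 2 + 3 * 2 = 8 entries.
assert expanded_request_count(2, 3) == 8
assert expanded_request_count(2, 1) == 2
```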
@@ -172,12 +165,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The pixel values for input images
-    pixel_values: List[float]
-    # The hash values of input images
-    image_hashes: List[int]
-    # The image sizes
-    image_sizes: List[List[int]]
+    # The image input
+    image_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -188,8 +177,6 @@ class TokenizedGenerateReqInput:
     top_logprobs_num: int
     # Whether to stream output
     stream: bool
-    # Modalities of the input images
-    modalites: Optional[List[str]] = None
 
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
@@ -206,8 +193,6 @@ class EmbeddingReqInput:
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
 
-    is_single: bool = True
-
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
@@ -215,12 +200,11 @@
             raise ValueError("Either text or input_ids should be provided.")
 
         if self.text is not None:
-            is_single = isinstance(self.text, str)
+            self.is_single = isinstance(self.text, str)
         else:
-            is_single = isinstance(self.input_ids[0], int)
-        self.is_single = is_single
+            self.is_single = isinstance(self.input_ids[0], int)
 
-        if is_single:
+        if self.is_single:
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
             if self.sampling_params is None:
@@ -254,6 +238,50 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams
 
 
+@dataclass
+class RewardReqInput:
+    # The input prompt in the chat format. It can be a single prompt or a batch of prompts.
+    conv: Union[List[List[Dict]], List[Dict]]
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Dummy sampling params for compatibility
+    sampling_params: Union[List[Dict], Dict] = None
+
+    def post_init(self):
+        self.is_single = isinstance(self.conv[0], dict)
+
+        if self.is_single:
+            if self.rid is None:
+                self.rid = uuid.uuid4().hex
+            if self.sampling_params is None:
+                self.sampling_params = {}
+            self.sampling_params["max_new_tokens"] = 1
+        else:
+            # support select operation
+            self.batch_size = len(self.conv)
+            if self.rid is None:
+                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
+            else:
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
+            if self.sampling_params is None:
+                self.sampling_params = [{}] * self.batch_size
+            for i in range(self.batch_size):
+                self.sampling_params[i]["max_new_tokens"] = 1
+
+
+@dataclass
+class TokenizedRewardReqInput:
+    # The request id
+    rid: str
+    # The input text
+    input_text: str
+    # The input token ids
+    input_ids: List[int]
+    # Dummy sampling params for compatibility
+    sampling_params: SamplingParams
+
+
 @dataclass
 class BatchTokenIDOut:
     # The request id
@@ -268,10 +296,6 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
 
-    def __post_init__(self):
-        # deepcopy meta_info to avoid modification in place
-        self.meta_info = copy.deepcopy(self.meta_info)
-
 
 @dataclass
 class BatchStrOut:
sglang/srt/managers/schedule_batch.py

@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,7 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Meta data for requests and batches"""
+"""
+Store information about requests and batches.
+
+The following is the flow of data structures for a batch:
+
+ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
+
+- ScheduleBatch is managed by `scheduler.py::Scheduler`.
+  It contains high-level scheduling data. Most of the data is on the CPU.
+- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+- ForwardBatch is managed by `model_runner.py::ModelRunner`.
+  It contains low-level tensor data. Most of the data consists of GPU tensors.
+"""
 
 import logging
 from dataclasses import dataclass
@@ -31,6 +41,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -102,14 +113,50 @@ class FINISH_ABORT(BaseFinishReason):
         }
 
 
+@dataclass
+class ImageInputs:
+    """The image related inputs."""
+
+    pixel_values: torch.Tensor
+    image_hash: int
+    image_sizes: Optional[list] = None
+    image_offsets: Optional[list] = None
+    pad_values: Optional[list] = None
+    modalities: Optional[list] = None
+
+    image_embeds: Optional[List[torch.Tensor]] = None
+    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def from_dict(obj, vocab_size):
+        # Use image hash as fake token_ids, which is then used for prefix matching
+        ret = ImageInputs(
+            pixel_values=obj["pixel_values"],
+            image_hash=hash(tuple(obj["image_hashes"])),
+        )
+        image_hash = ret.image_hash
+        ret.pad_values = [
+            (image_hash) % vocab_size,
+            (image_hash >> 16) % vocab_size,
+            (image_hash >> 32) % vocab_size,
+            (image_hash >> 64) % vocab_size,
+        ]
+        ret.image_sizes = obj["image_sizes"]
+        # Only when pixel values is not None we have modalities
+        ret.modalities = obj["modalities"] or ["image"]
+        return ret
+
+
 class Req:
-    """Store all inforamtion of a request."""
+    """The input and output status of a request."""
 
     def __init__(
         self,
         rid: str,
         origin_input_text: str,
         origin_input_ids: Tuple[int],
+        sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
     ):
         # Input and output info
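
The new `ImageInputs.from_dict` above folds all per-image hashes into a single `image_hash` and derives vocabulary-sized pad values from it, so padded image regions behave like ordinary token ids for prefix matching. A small self-contained illustration of that derivation (standalone sketch, not importing sglang; the 32000 vocabulary size is an arbitrary example value):

```python
def pad_values_from_image_hashes(image_hashes, vocab_size=32000):
    # Combine all per-image hashes into one value, as ImageInputs.from_dict does.
    image_hash = hash(tuple(image_hashes))
    # Slice the hash into four vocabulary-sized "fake token ids" used for padding
    # the prompt and therefore for prefix matching in the radix cache.
    return [
        image_hash % vocab_size,
        (image_hash >> 16) % vocab_size,
        (image_hash >> 32) % vocab_size,
        (image_hash >> 64) % vocab_size,
    ]


print(pad_values_from_image_hashes([123456789, 987654321]))
```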
@@ -119,6 +166,8 @@ class Req:
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
+
+        self.sampling_params = sampling_params
         self.lora_path = lora_path
 
         # Memory info
@@ -127,6 +176,7 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        self.stream = False
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -147,21 +197,13 @@ class Req:
         self.completion_tokens_wo_jump_forward = 0
 
         # For vision inputs
-        self.pixel_values = None
-        self.image_sizes = None
-        self.image_offsets = None
-        self.pad_value = None
-        self.modalities = None
+        self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
         self.prefix_indices = []
        self.extend_input_len = 0
         self.last_node = None
 
-        # Sampling parameters
-        self.sampling_params = None
-        self.stream = False
-
         # Logprobs (arguments)
         self.return_logprob = False
         self.logprob_start_len = 0
@@ -363,28 +405,32 @@ class ScheduleBatch:
     sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
-    input_ids: torch.Tensor = None
-    req_pool_indices: torch.Tensor = None
-    seq_lens: torch.Tensor = None
-    position_ids_offsets: torch.Tensor = None
+    input_ids: List[int] = None
+    req_pool_indices: List[int] = None
+    seq_lens: List[int] = None
     out_cache_loc: torch.Tensor = None
-    extend_num_tokens: int = None
-
-    # For mixed chunekd prefill
-    prefix_lens_cpu: List[int] = None
-    running_bs: int = None
 
     # For processing logprobs
     return_logprob: bool = False
-    top_logprobs_nums: List[int] = None
+    top_logprobs_nums: Optional[List[int]] = None
+
+    # For extend and mixed chunekd prefill
+    prefix_lens: List[int] = None
+    extend_lens: List[int] = None
+    extend_num_tokens: int = None
+    running_bs: int = None
 
     # Stream
     has_stream: bool = False
 
+    # Has regex
+    has_regex: bool = False
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
         has_stream = any(req.stream for req in reqs)
+        has_regex = any(req.regex_fsm for req in reqs)
 
         return cls(
             reqs=reqs,
@@ -393,6 +439,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            has_regex=has_regex,
         )
 
     def batch_size(self):
@@ -436,12 +483,12 @@ class ScheduleBatch:
         seq_lens = []
 
         # Allocate memory
-        req_pool_indices_cpu = self.alloc_req_slots(bs)
+        req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
 
         pt = 0
         for i, req in enumerate(reqs):
-            req.req_pool_idx = req_pool_indices_cpu[i]
+            req.req_pool_idx = req_pool_indices[i]
             pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
             seq_lens.append(seq_len)
             assert seq_len - pre_len == req.extend_input_len
@@ -467,18 +514,19 @@ class ScheduleBatch:
             pt += req.extend_input_len
 
         # Set fields
-        with torch.device("cuda"):
+        with out_cache_loc.device:
             self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
-            self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
-            self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
-            self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int64)
+            self.req_pool_indices = torch.tensor(req_pool_indices)
+            self.seq_lens = torch.tensor(seq_lens)
 
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
-        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
-        self.prefix_lens_cpu = [len(r.prefix_indices) for r in reqs]
-        self.extend_lens_cpu = [r.extend_input_len for r in reqs]
-        self.extend_logprob_start_lens_cpu = [r.extend_logprob_start_len for r in reqs]
+        if self.return_logprob:
+            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
+        self.extend_lens = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(self, vocab_size)
 
     def mix_with_running(self, running_batch: "ScheduleBatch"):
@@ -493,20 +541,20 @@ class ScheduleBatch:
         out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
         extend_num_tokens = self.extend_num_tokens + running_bs
 
-        self.merge(running_batch)
+        self.merge_batch(running_batch)
         self.input_ids = input_ids
         self.out_cache_loc = out_cache_loc
         self.extend_num_tokens = extend_num_tokens
 
         # NOTE: prefix_indices is what has been cached, but we don't cache each decode step
-        self.prefix_lens_cpu.extend(
+        self.prefix_lens.extend(
             [
                 len(r.origin_input_ids) + len(r.output_ids) - 1
                 for r in running_batch.reqs
             ]
         )
-        self.extend_lens_cpu.extend([1] * running_bs)
-        self.extend_logprob_start_lens_cpu.extend([0] * running_bs)
+        self.extend_lens.extend([1] * running_bs)
+        self.extend_logprob_start_lens.extend([0] * running_bs)
 
     def check_decode_mem(self):
         bs = len(self.reqs)
@@ -598,7 +646,7 @@ class ScheduleBatch:
 
         return retracted_reqs, new_estimate_ratio
 
-    def check_for_jump_forward(self, model_runner):
+    def check_for_jump_forward(self, pad_input_ids_func):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
 
@@ -654,15 +702,9 @@ class ScheduleBatch:
                 self.tree_cache.cache_finished_req(req, cur_all_ids)
 
                 # re-applying image padding
-                if req.pixel_values is not None:
-                    (
-                        req.origin_input_ids,
-                        req.image_offsets,
-                    ) = model_runner.model.pad_input_ids(
-                        req.origin_input_ids_unpadded,
-                        req.pad_value,
-                        req.pixel_values,
-                        req.image_sizes,
+                if req.image_inputs is not None:
+                    req.origin_input_ids = pad_input_ids_func(
+                        req.origin_input_ids_unpadded, req.image_inputs
                     )
 
                 jump_forward_reqs.append(req)
@@ -681,7 +723,9 @@ class ScheduleBatch:
             for r in self.reqs
         ]
 
-        self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda")
+        self.input_ids = torch.tensor(
+            input_ids, dtype=torch.int32, device=self.seq_lens.device
+        )
         self.seq_lens.add_(1)
 
         # Alloc mem
@@ -703,33 +747,110 @@ class ScheduleBatch:
             return
 
         self.reqs = [self.reqs[i] for i in unfinished_indices]
-        new_indices = torch.tensor(unfinished_indices, dtype=torch.int32, device="cuda")
-        self.seq_lens = self.seq_lens[new_indices]
-        self.input_ids = None
+        new_indices = torch.tensor(
+            unfinished_indices, dtype=torch.int32, device=self.seq_lens.device
+        )
         self.req_pool_indices = self.req_pool_indices[new_indices]
-        self.position_ids_offsets = self.position_ids_offsets[new_indices]
+        self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
-        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
+        if self.return_logprob:
+            self.top_logprobs_nums = [
+                self.top_logprobs_nums[i] for i in unfinished_indices
+            ]
+        else:
+            self.top_logprobs_nums = None
+
         self.has_stream = any(req.stream for req in self.reqs)
+        self.has_regex = any(req.regex_fsm for req in self.reqs)
 
-        self.sampling_info.filter(unfinished_indices, new_indices)
+        self.sampling_info.filter_batch(unfinished_indices, new_indices)
 
-    def merge(self, other: "ScheduleBatch"):
+    def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
         # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
         # needs to be called with pre-merged Batch.reqs.
-        self.sampling_info.merge(other.sampling_info)
+        self.sampling_info.merge_batch(other.sampling_info)
 
-        self.reqs.extend(other.reqs)
         self.req_pool_indices = torch.concat(
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
-        self.position_ids_offsets = torch.concat(
-            [self.position_ids_offsets, other.position_ids_offsets]
-        )
        self.out_cache_loc = None
-        self.top_logprobs_nums.extend(other.top_logprobs_nums)
-        self.return_logprob = any(req.return_logprob for req in self.reqs)
-        self.has_stream = any(req.stream for req in self.reqs)
+        if self.return_logprob and other.return_logprob:
+            self.top_logprobs_nums.extend(other.top_logprobs_nums)
+        elif self.return_logprob:
+            self.top_logprobs_nums.extend([0] * len(other.reqs))
+        elif other.return_logprob:
+            self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
+        self.reqs.extend(other.reqs)
+
+        self.return_logprob = self.return_logprob or other.return_logprob
+        self.has_stream = self.has_stream or other.has_stream
+        self.has_regex = self.has_regex or other.has_regex
+
+    def get_model_worker_batch(self):
+        if self.forward_mode.is_decode():
+            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = (
+                image_inputs
+            ) = None
+        else:
+            extend_seq_lens = self.extend_lens
+            extend_prefix_lens = self.prefix_lens
+            extend_logprob_start_lens = self.extend_logprob_start_lens
+            image_inputs = [r.image_inputs for r in self.reqs]
+
+        lora_paths = [req.lora_path for req in self.reqs]
+        if self.has_regex:
+            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
+            self.sampling_info.regex_fsm_states = [
+                req.regex_fsm_state for req in self.reqs
+            ]
+
+        return ModelWorkerBatch(
+            forward_mode=self.forward_mode,
+            input_ids=self.input_ids,
+            req_pool_indices=self.req_pool_indices,
+            seq_lens=self.seq_lens,
+            out_cache_loc=self.out_cache_loc,
+            return_logprob=self.return_logprob,
+            top_logprobs_nums=self.top_logprobs_nums,
+            extend_seq_lens=extend_seq_lens,
+            extend_prefix_lens=extend_prefix_lens,
+            extend_logprob_start_lens=extend_logprob_start_lens,
+            image_inputs=image_inputs,
+            lora_paths=lora_paths,
+            sampling_info=self.sampling_info,
+        )
+
+
+@dataclass
+class ModelWorkerBatch:
+    # The forward mode
+    forward_mode: ForwardMode
+    # The input ids
+    input_ids: torch.Tensor
+    # The indices of requests in the req_to_token_pool
+    req_pool_indices: torch.Tensor
+    # The sequence length
+    seq_lens: torch.Tensor
+    # The indices of output tokens in the token_to_kv_pool
+    out_cache_loc: torch.Tensor
+
+    # For logprob
+    return_logprob: bool
+    top_logprobs_nums: Optional[List[int]]
+
+    # For extend
+    extend_seq_lens: Optional[List[int]]
+    extend_prefix_lens: Optional[List[int]]
+    extend_logprob_start_lens: Optional[List[int]]
+
+    # For multimodal
+    image_inputs: Optional[List[ImageInputs]]
+
+    # For LoRA
+    lora_paths: Optional[List[str]]
+
+    # Sampling info
+    sampling_info: SamplingBatchInfo
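
Taken together with the new module docstring, the flow is: the scheduler keeps a CPU-side `ScheduleBatch`, strips it down to a `ModelWorkerBatch` via `get_model_worker_batch()`, and the TP worker turns that into the GPU-side `ForwardBatch` consumed by the model runner. A schematic, self-contained sketch of that handoff pattern (simplified stand-in classes, not the actual sglang implementations):

```python
from dataclasses import dataclass
from typing import List

import torch


@dataclass
class ScheduleBatchSketch:
    # CPU-side scheduling view (stand-in for ScheduleBatch).
    input_ids: List[int]
    seq_lens: List[int]

    def get_model_worker_batch(self) -> "ModelWorkerBatchSketch":
        # Keep only what the model forward needs.
        return ModelWorkerBatchSketch(self.input_ids, self.seq_lens)


@dataclass
class ModelWorkerBatchSketch:
    # Handed from the scheduler to the TP worker (stand-in for ModelWorkerBatch).
    input_ids: List[int]
    seq_lens: List[int]


@dataclass
class ForwardBatchSketch:
    # Device tensors built for the model runner (stand-in for ForwardBatch).
    input_ids: torch.Tensor
    seq_lens: torch.Tensor

    @classmethod
    def from_worker_batch(cls, wb: ModelWorkerBatchSketch, device: str = "cpu"):
        return cls(
            input_ids=torch.tensor(wb.input_ids, dtype=torch.int32, device=device),
            seq_lens=torch.tensor(wb.seq_lens, dtype=torch.int32, device=device),
        )


# Scheduler -> worker -> model runner handoff:
sb = ScheduleBatchSketch(input_ids=[1, 2, 3], seq_lens=[3])
fb = ForwardBatchSketch.from_worker_batch(sb.get_model_worker_batch())
```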