sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_latency.py +2 -1
  2. sglang/lang/chat_template.py +17 -0
  3. sglang/launch_server_llavavid.py +1 -1
  4. sglang/srt/configs/__init__.py +3 -0
  5. sglang/srt/configs/model_config.py +27 -2
  6. sglang/srt/configs/qwen2vl.py +133 -0
  7. sglang/srt/constrained/fsm_cache.py +10 -3
  8. sglang/srt/conversation.py +27 -0
  9. sglang/srt/hf_transformers_utils.py +16 -1
  10. sglang/srt/layers/attention/__init__.py +16 -5
  11. sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
  12. sglang/srt/layers/attention/flashinfer_backend.py +174 -54
  13. sglang/srt/layers/attention/triton_backend.py +22 -6
  14. sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
  15. sglang/srt/layers/linear.py +89 -63
  16. sglang/srt/layers/logits_processor.py +5 -5
  17. sglang/srt/layers/rotary_embedding.py +112 -0
  18. sglang/srt/layers/sampler.py +51 -39
  19. sglang/srt/lora/lora.py +3 -1
  20. sglang/srt/managers/data_parallel_controller.py +1 -1
  21. sglang/srt/managers/detokenizer_manager.py +4 -0
  22. sglang/srt/managers/image_processor.py +186 -13
  23. sglang/srt/managers/io_struct.py +10 -0
  24. sglang/srt/managers/schedule_batch.py +238 -68
  25. sglang/srt/managers/scheduler.py +69 -50
  26. sglang/srt/managers/tokenizer_manager.py +24 -4
  27. sglang/srt/managers/tp_worker.py +26 -111
  28. sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
  29. sglang/srt/mem_cache/memory_pool.py +56 -10
  30. sglang/srt/mem_cache/radix_cache.py +4 -3
  31. sglang/srt/model_executor/cuda_graph_runner.py +87 -28
  32. sglang/srt/model_executor/forward_batch_info.py +83 -3
  33. sglang/srt/model_executor/model_runner.py +32 -11
  34. sglang/srt/models/chatglm.py +3 -3
  35. sglang/srt/models/deepseek_v2.py +2 -2
  36. sglang/srt/models/mllama.py +1004 -0
  37. sglang/srt/models/qwen2_vl.py +724 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  39. sglang/srt/sampling/sampling_batch_info.py +13 -3
  40. sglang/srt/sampling/sampling_params.py +5 -7
  41. sglang/srt/server.py +12 -0
  42. sglang/srt/server_args.py +10 -0
  43. sglang/srt/utils.py +22 -0
  44. sglang/test/run_eval.py +2 -0
  45. sglang/test/runners.py +20 -1
  46. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  47. sglang/test/test_utils.py +100 -3
  48. sglang/version.py +1 -1
  49. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
  50. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
  51. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
  53. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker_overlap_thread.py (new file)
@@ -0,0 +1,209 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""A tensor parallel worker."""
+
+import logging
+import threading
+import time
+from queue import Queue
+from typing import Optional
+
+import torch
+
+from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+
+@torch.compile(dynamic=True)
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
+class TpModelWorkerClient:
+    """A tensor parallel model worker."""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        nccl_port: int,
+    ):
+        # Load the model
+        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
+        self.max_running_requests = self.worker.max_running_requests
+        self.device = self.worker.device
+
+        # Init future mappings
+        self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
+        self.future_token_ids_map = torch.empty(
+            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+        )
+
+        # Launch threads
+        self.input_queue = Queue()
+        self.output_queue = Queue()
+        self.forward_stream = torch.cuda.Stream()
+        self.forward_thread = threading.Thread(
+            target=self.forward_thread_func,
+        )
+        self.forward_thread.start()
+
+        self.copy_queue = Queue()
+        self.copy_thread = threading.Thread(
+            target=self.copy_thread_func,
+        )
+        self.copy_thread.start()
+
+    def get_worker_info(self):
+        return self.worker.get_worker_info()
+
+    def get_pad_input_ids_func(self):
+        return self.worker.get_pad_input_ids_func()
+
+    def get_tp_cpu_group(self):
+        return self.worker.get_tp_cpu_group()
+
+    def get_memory_pool(self):
+        return (
+            self.worker.model_runner.req_to_token_pool,
+            self.worker.model_runner.token_to_kv_pool,
+        )
+
+    def forward_thread_func(self):
+        with torch.cuda.stream(self.forward_stream):
+            self.forward_thread_func_()
+
+    @torch.inference_mode()
+    def forward_thread_func_(self):
+        while True:
+            self.has_inflight_batch = False
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            if not model_worker_batch:
+                break
+            self.has_inflight_batch = True
+            self.launch_event = threading.Event()
+
+            # Resolve future tokens in the input
+            input_ids = model_worker_batch.input_ids
+            resolve_future_token_ids(input_ids, self.future_token_ids_map)
+
+            # Run forward
+            logits_output, next_token_ids = self.worker.forward_batch_generation(
+                model_worker_batch
+            )
+
+            # Update the future token ids map
+            bs = len(model_worker_batch.seq_lens)
+            self.future_token_ids_map[
+                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+            ] = next_token_ids
+
+            # Copy results to the CPU
+            if model_worker_batch.return_logprob:
+                logits_output.next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].to("cpu", non_blocking=True)
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.to(
+                            "cpu", non_blocking=True
+                        )
+                    )
+            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
+            copy_event = torch.cuda.Event(blocking=True)
+            copy_event.record()
+
+            self.launch_event.set()
+            self.copy_queue.put((copy_event, logits_output, next_token_ids))
+
+    def copy_thread_func(self):
+        while True:
+            copy_event, logits_output, next_token_ids = self.copy_queue.get()
+            if not copy_event:
+                break
+            while not copy_event.query():
+                time.sleep(1e-5)
+
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
+                    )
+
+            self.output_queue.put((logits_output, next_token_ids.tolist()))
+
+    def resulve_batch_result(self, bid: int):
+        logits_output, next_token_ids = self.output_queue.get()
+        if self.has_inflight_batch:
+            # Wait until the batch is launched
+            self.launch_event.wait()
+        return logits_output, next_token_ids
+
+    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+        # Push a new batch to the queue
+        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
+
+        # Allocate output future objects
+        bs = len(model_worker_batch.seq_lens)
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + 1),
+            -(self.future_token_ids_ct + 1 + bs),
+            -1,
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.future_token_ids_ct = (
+            self.future_token_ids_ct + bs
+        ) % self.future_token_ids_limit
+        return None, future_next_token_ids
+
+    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        logits_output = self.model_runner.forward(forward_batch)
+        embeddings = logits_output.embeddings
+        return embeddings
+
+    def update_weights(self, recv_req: UpdateWeightReqInput):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        return success, message
+
+    def __delete__(self):
+        self.input_queue.put((None, None))
+        self.copy_queue.put((None, None, None))
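The new `TpModelWorkerClient` overlaps scheduling with GPU execution by handing out negative placeholder token ids immediately and filling in the real ids once sampling finishes; `resolve_future_token_ids` then swaps the placeholders for the published values. Below is a minimal CPU-only sketch of that placeholder arithmetic. The tensor names mirror the diff, but the concrete values, the int64 dtypes, and the absence of streams and threads are illustrative assumptions (the real map is an int32 tensor on the GPU).

```python
import torch

# Shared table indexed by placeholder id; stands in for self.future_token_ids_map.
future_token_ids_map = torch.zeros(16, dtype=torch.int64)

# Scheduler side: hand out placeholders -(ct+1) .. -(ct+bs) without waiting for the GPU,
# exactly as forward_batch_generation does.
ct, bs = 0, 3
placeholders = torch.arange(-(ct + 1), -(ct + 1 + bs), -1)  # tensor([-1, -2, -3])

# Worker side: after sampling, publish the real token ids into the table.
next_token_ids = torch.tensor([101, 102, 103])
future_token_ids_map[ct + 1 : ct + bs + 1] = next_token_ids

# The next step's input may mix real tokens with unresolved placeholders;
# every negative id is replaced by its table entry.
input_ids = torch.cat([torch.tensor([7]), placeholders])
input_ids[:] = torch.where(
    input_ids < 0,
    future_token_ids_map[torch.clamp(-input_ids, min=0)],
    input_ids,
)
print(input_ids)  # tensor([  7, 101, 102, 103])
```

The point of the indirection is that the scheduler can assemble the following batch around the placeholders without synchronizing on the sampled tokens; the map is filled on the forward stream before that batch actually runs.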
sglang/srt/mem_cache/memory_pool.py
@@ -13,27 +13,46 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""Memory pool."""
+"""
+Memory pool.
+
+SGLang has two levels of memory pool.
+ReqToTokenPool maps a a request to its token locations.
+BaseTokenToKVPool maps a token location to its KV cache data.
+"""
 
 import logging
 from typing import List, Tuple, Union
 
 import torch
 
+from sglang.srt.layers.radix_attention import RadixAttention
+
 logger = logging.getLogger(__name__)
 
 
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int, device: str):
+    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.empty(
+        self.req_to_token = torch.zeros(
             (size, max_context_len), dtype=torch.int32, device=device
         )
         self.free_slots = list(range(size))
+        self.write_records = []
+        self.use_records = use_records
+
+        if self.use_records:
+            self.write = self.write_with_records
+        else:
+            self.write = self.write_without_records
+
+    def write(self, indices, values):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
 
     def available_size(self):
         return len(self.free_slots)
@@ -55,10 +74,27 @@ class ReqToTokenPool:
 
     def clear(self):
         self.free_slots = list(range(self.size))
+        self.write_records = []
+
+    def write_without_records(self, indices, values):
+        self.req_to_token[indices] = values
+
+    def write_with_records(self, indices, values):
+        self.req_to_token[indices] = values
+        self.write_records.append((indices, values))
+
+    def get_write_records(self):
+        ret = self.write_records
+        self.write_records = []
+        return ret
+
+    def apply_write_records(self, write_records: List[Tuple]):
+        for indices, values in write_records:
+            self.req_to_token[indices] = values
 
 
 class BaseTokenToKVPool:
-    """A memory pool that maps a token to its kv cache locations"""
+    """A memory pool that maps a token location to its kv cache data."""
 
     def __init__(
         self,
@@ -68,12 +104,12 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
+        self.device = device
 
         self.free_slots = None
         self.is_not_in_free_group = True
@@ -124,7 +160,7 @@ class BaseTokenToKVPool:
 
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
@@ -179,14 +215,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
-        if cache_v.dtype != self.dtype:
             cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
@@ -196,6 +232,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
             self.v_buffer[layer_id][loc] = cache_v
 
 
+# This compiled version is slower in the unit test
+# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
+@torch.compile(dynamic=True)
+def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
+    dst_1[loc] = src_1.to(dtype).view(store_dtype)
+    dst_2[loc] = src_2.to(dtype).view(store_dtype)
+
+
 class MLATokenToKVPool(BaseTokenToKVPool):
 
     def __init__(
@@ -235,11 +279,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
 
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
@@ -294,13 +339,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
 
     def set_kv_buffer(
         self,
-        layer_id: int,
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
         cache_label: torch.Tensor,
     ):
         # NOTE(Andy): ignore the dtype check
+        layer_id = layer.layer_id
         self.k_buffer[layer_id][loc] = cache_k
         self.v_buffer[layer_id][loc] = cache_v
         self.label_buffer[layer_id][loc] = cache_label
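As the NOTE in the `BaseTokenToKVPool` hunk above explains, fp8 KV entries are stored through a `torch.uint8` view because `index_put` is not implemented for `torch.float8_e5m2`. A minimal sketch of that round trip, assuming a PyTorch build with fp8 support; the buffer shape and locations here are made up:

```python
import torch

dtype, store_dtype = torch.float8_e5m2, torch.uint8

# Hypothetical per-layer K buffer, allocated in the storage dtype (uint8).
k_buffer = torch.zeros(8, 4, dtype=store_dtype)

loc = torch.tensor([2, 5])                                  # token slots to write
cache_k = torch.randn(2, 4, dtype=torch.float16).to(dtype)  # new fp8 K entries

k_buffer[loc] = cache_k.view(store_dtype)  # write through a byte view (index_put on uint8)
recovered = k_buffer[loc].view(dtype)      # read back by reinterpreting the bytes as fp8
print(recovered.dtype)                     # torch.float8_e5m2
```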
sglang/srt/mem_cache/radix_cache.py
@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node = self.match_prefix(token_ids)
         assert len(new_indices) == len(token_ids)
-        self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
-        ] = new_indices[len(req.prefix_indices) :]
+        self.req_to_token_pool.write(
+            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+            new_indices[len(req.prefix_indices) :],
+        )
 
         self.dec_lock_ref(req.last_node)
         self.inc_lock_ref(new_last_node)
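With these changes, `ReqToTokenPool` routes all updates through `write()`, which can optionally record each `(indices, values)` pair for later replay, and callers such as `RadixCache` above now pass a `(row, slice)` index tuple instead of touching `req_to_token` directly. Below is a trimmed stand-in showing the record-and-replay flow; `MiniReqToTokenPool` and its sizes are hypothetical, alloc/free bookkeeping is omitted, and replaying onto a second copy of the table is an assumption about how the records are consumed.

```python
import torch
from typing import List, Tuple

class MiniReqToTokenPool:
    def __init__(self, size: int, max_context_len: int, use_records: bool):
        self.req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)
        self.write_records = []
        # Pick the write implementation once, as in the diff above.
        self.write = self.write_with_records if use_records else self.write_without_records

    def write_without_records(self, indices, values):
        self.req_to_token[indices] = values

    def write_with_records(self, indices, values):
        self.req_to_token[indices] = values
        self.write_records.append((indices, values))

    def get_write_records(self) -> List[Tuple]:
        ret, self.write_records = self.write_records, []
        return ret

    def apply_write_records(self, write_records: List[Tuple]):
        for indices, values in write_records:
            self.req_to_token[indices] = values

# A recording pool logs each (indices, values) pair as it writes ...
src = MiniReqToTokenPool(size=4, max_context_len=8, use_records=True)
src.write((0, slice(0, 3)), torch.tensor([11, 12, 13], dtype=torch.int32))  # row 0, cols 0:3

# ... so another copy of the table can be brought in sync by replaying them.
dst = MiniReqToTokenPool(size=4, max_context_len=8, use_records=False)
dst.apply_write_records(src.get_write_records())
assert torch.equal(dst.req_to_token[0, :3], torch.tensor([11, 12, 13], dtype=torch.int32))
```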
sglang/srt/model_executor/cuda_graph_runner.py
@@ -92,6 +92,11 @@ def set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 1024
 
 
+@torch.compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
 
@@ -105,13 +110,13 @@ class CudaGraphRunner:
         self.graph_memory_pool = None
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
+        self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
 
         # Batch sizes to capture
         if self.model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
-            self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
+            self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
         self.capture_bs = [
             bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
         ]
@@ -128,10 +133,14 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
 
+        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
+        self.encoder_len_fill_value = 0
+
         if self.use_torch_compile:
             set_torch_compile_config()
 
@@ -143,10 +152,20 @@ class CudaGraphRunner:
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+        self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
+
+        if self.is_encoder_decoder:
+            # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
+            self.encoder_lens = torch.full(
+                (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
+            )
+        else:
+            self.encoder_lens = None
 
         # Capture
         try:
-            self.capture()
+            with self.model_capture_mode():
+                self.capture()
         except RuntimeError as e:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"
@@ -157,11 +176,32 @@ class CudaGraphRunner:
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
 
-    def can_run(self, batch_size: int):
-        if self.disable_padding:
-            return batch_size in self.graphs
-        else:
-            return batch_size <= self.max_bs
+    @contextmanager
+    def model_capture_mode(self):
+        if hasattr(self.model_runner.model, "capture_mode"):
+            self.model_runner.model.capture_mode = True
+
+        yield
+
+        if hasattr(self.model_runner.model, "capture_mode"):
+            self.model_runner.model.capture_mode = False
+
+    def can_run(self, forward_batch: ForwardBatch):
+        is_bs_supported = (
+            forward_batch.batch_size in self.graphs
+            if self.disable_padding
+            else forward_batch.batch_size <= self.max_bs
+        )
+
+        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
+        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
+        # because the full_text_row_masked_out_mask tensor will always be ones
+        is_encoder_lens_supported = (
+            torch.all(forward_batch.encoder_lens > 0)
+            if self.is_encoder_decoder
+            else True
+        )
+        return is_bs_supported and is_encoder_lens_supported
 
     def capture(self):
         with graph_capture() as graph_capture_context:
@@ -188,10 +228,20 @@ class CudaGraphRunner:
         req_pool_indices = self.req_pool_indices[:bs]
         seq_lens = self.seq_lens[:bs]
         out_cache_loc = self.out_cache_loc[:bs]
+        if self.is_encoder_decoder:
+            encoder_lens = self.encoder_lens[:bs]
+        else:
+            encoder_lens = None
+
+        seq_lens_sum = seq_lens.sum().item()
+        mrope_positions = self.mrope_positions[:, :bs]
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
-            bs, req_pool_indices, seq_lens
+            bs,
+            req_pool_indices,
+            seq_lens,
+            encoder_lens,
         )
 
         # Run and capture
@@ -206,11 +256,15 @@ class CudaGraphRunner:
                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
                 attn_backend=self.model_runner.attn_backend,
                 out_cache_loc=out_cache_loc,
+                seq_lens_sum=seq_lens_sum,
+                encoder_lens=encoder_lens,
                 return_logprob=False,
                 top_logprobs_nums=[0] * bs,
-                positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
+                positions=clamp_position(seq_lens),
+                mrope_positions=mrope_positions,
             )
-            return forward(input_ids, forward_batch.positions, forward_batch)
+            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
+            return logits_output.next_token_logits
 
         for _ in range(2):
             torch.cuda.synchronize()
@@ -241,7 +295,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(self.seq_len_fill_value)
+            self.seq_lens.fill_(1)
            self.out_cache_loc.zero_()
 
         # Common inputs
@@ -249,31 +303,32 @@ class CudaGraphRunner:
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+        if self.is_encoder_decoder:
+            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
+        if forward_batch.mrope_positions is not None:
+            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
-            bs, self.req_pool_indices, self.seq_lens
+            bs,
+            self.req_pool_indices,
+            self.seq_lens,
+            forward_batch.seq_lens_sum + (bs - raw_bs),
+            self.encoder_lens,
        )
 
         # Replay
         self.graphs[bs].replay()
-        logits_output = self.output_buffers[bs]
-
-        # Unpad
-        if bs != raw_bs:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=logits_output.next_token_logits[:raw_bs],
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
-            )
+        next_token_logits = self.output_buffers[bs][:raw_bs]
 
         # Extract logprobs
         if forward_batch.return_logprob:
-            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
-                logits_output.next_token_logits, dim=-1
+            next_token_logprobs = torch.nn.functional.log_softmax(
+                next_token_logits, dim=-1
+            )
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+                next_token_logprobs=next_token_logprobs,
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
@@ -282,7 +337,11 @@ class CudaGraphRunner:
                     top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    logits_output.next_token_logprobs, logits_metadata
+                    next_token_logprobs, logits_metadata
                 )[1]
+        else:
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+            )
 
         return logits_output
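In the replay path above, a raw batch size is rounded up to the nearest captured size with `bisect`, and the padded tail rows get harmless filler values (`self.seq_lens.fill_(1)`); only the first `raw_bs` rows of the replayed output are kept afterwards (`self.output_buffers[bs][:raw_bs]`). Here is a small sketch of just that padding step; `pad_batch` and the concrete numbers are illustrative, while `capture_bs` mirrors the non-`disable_cuda_graph_padding` branch of the diff.

```python
import bisect
import torch

capture_bs = sorted([1, 2, 3, 4] + [i * 8 for i in range(1, 21)])  # 1, 2, 3, 4, 8, 16, ..., 160

def pad_batch(raw_bs: int, seq_lens: torch.Tensor):
    # Round the batch size up to the nearest size a cuda graph was captured for.
    bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
    padded = torch.ones(bs, dtype=torch.int32)  # filler rows get seq_len = 1
    padded[:raw_bs] = seq_lens
    return bs, padded

bs, padded = pad_batch(5, torch.tensor([9, 3, 17, 4, 11], dtype=torch.int32))
print(bs, padded)  # 8 tensor([ 9,  3, 17,  4, 11,  1,  1,  1], dtype=torch.int32)
```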