sglang 0.4.3.post3__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
@@ -159,17 +159,6 @@ class Scheduler:
         )
         self.gpu_id = gpu_id
         self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
-        self.decode_mem_cache_buf_multiplier = (
-            (
-                self.server_args.speculative_num_draft_tokens
-                + (
-                    self.server_args.speculative_eagle_topk
-                    * self.server_args.speculative_num_draft_tokens
-                )
-            )
-            if not self.spec_algorithm.is_none()
-            else 1
-        )
 
         # Distributed rank info
         self.dp_size = server_args.dp_size
@@ -208,42 +197,12 @@ class Scheduler:
             self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
 
         # Init tokenizer
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
-            context_length=server_args.context_length,
-            model_override_args=server_args.json_model_override_args,
-            is_embedding=server_args.is_embedding,
-            dtype=server_args.dtype,
-            quantization=server_args.quantization,
-        )
-        self.is_generation = self.model_config.is_generation
-
-        if server_args.skip_tokenizer_init:
-            self.tokenizer = self.processor = None
-        else:
-            if self.model_config.is_multimodal:
-                self.processor = get_processor(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                )
-                self.tokenizer = self.processor.tokenizer
-            else:
-                self.tokenizer = get_tokenizer(
-                    server_args.tokenizer_path,
-                    tokenizer_mode=server_args.tokenizer_mode,
-                    trust_remote_code=server_args.trust_remote_code,
-                    revision=server_args.revision,
-                )
+        self.init_tokenizer()
 
         # Check whether overlap can be enabled
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
-
         if self.model_config.is_multimodal:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for multimodal models.")
@@ -274,10 +233,8 @@ class Scheduler:
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
-            self.prefill_only_one_req = True
         else:
             self.draft_worker = None
-            self.prefill_only_one_req = False
 
         # Get token and memory info from the model worker
         (
@@ -309,32 +266,7 @@ class Scheduler:
         )
 
         # Init memory pool and cache
-        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
-            self.tp_worker.get_memory_pool()
-        )
-
-        if (
-            server_args.chunked_prefill_size is not None
-            and server_args.disable_radix_cache
-        ):
-            self.tree_cache = ChunkCache(
-                req_to_token_pool=self.req_to_token_pool,
-                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-            )
-        else:
-            if self.enable_hierarchical_cache:
-                self.tree_cache = HiRadixCache(
-                    req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                )
-            else:
-                self.tree_cache = RadixCache(
-                    req_to_token_pool=self.req_to_token_pool,
-                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
-                    disable=server_args.disable_radix_cache,
-                )
-
-        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
+        self.init_memory_pool_and_cache()
 
         # Init running status
         self.waiting_queue: List[Req] = []
@@ -348,25 +280,13 @@ class Scheduler:
         self.forward_ct = 0
         self.forward_ct_decode = 0
         self.num_generated_tokens = 0
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
         self.last_decode_stats_tic = time.time()
         self.return_health_check_ct = 0
         self.current_stream = torch.get_device_module(self.device).current_stream()
         if self.device == "cpu":
             self.current_stream.synchronize = lambda: None  # No-op for CPU
 
-        # For metrics only.
-        # The largest prefill length of a single request
-        self._largest_prefill_len: int = 0
-        # The largest context length (prefill + generation) of a single request
-        self._largest_prefill_decode_len: int = 0
-        self.last_gen_throughput: float = 0.0
-        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-
-        # Session info
+        # Init session info
         self.sessions: Dict[str, Session] = {}
 
         # Init chunked prefill
@@ -387,11 +307,11 @@ class Scheduler:
         else:
             self.grammar_backend = None
 
-        # Init new token estimation
+        # Init schedule policy and new token estimation
+        self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
         assert (
             server_args.schedule_conservativeness >= 0
         ), "Invalid schedule_conservativeness"
-
         self.init_new_token_ratio = min(
             global_config.default_init_new_token_ratio
             * server_args.schedule_conservativeness,
@@ -430,14 +350,7 @@ class Scheduler:
         self.profiler_target_forward_ct: Optional[int] = None
 
         # Init metrics stats
-        self.stats = SchedulerStats()
-        if self.enable_metrics:
-            self.metrics_collector = SchedulerMetricsCollector(
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
-            )
+        self.init_metrics()
 
         # Init request dispatcher
         self._request_dispatcher = TypeBasedDispatcher(
@@ -460,39 +373,104 @@ class Scheduler:
                 (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
                 (ProfileReq, self.profile),
                 (GetInternalStateReq, self.get_internal_state),
+                (SetInternalStateReq, self.set_internal_state),
             ]
         )
 
-    def watchdog_thread(self):
-        """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
-        self.watchdog_last_forward_ct = 0
-        self.watchdog_last_time = time.time()
+    def init_tokenizer(self):
+        server_args = self.server_args
 
-        while True:
-            current = time.time()
-            if self.cur_batch is not None:
-                if self.watchdog_last_forward_ct == self.forward_ct:
-                    if current > self.watchdog_last_time + self.watchdog_timeout:
-                        logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
-                        break
-                else:
-                    self.watchdog_last_forward_ct = self.forward_ct
-                    self.watchdog_last_time = current
-            time.sleep(self.watchdog_timeout // 2)
+        self.model_config = ModelConfig(
+            server_args.model_path,
+            trust_remote_code=server_args.trust_remote_code,
+            revision=server_args.revision,
+            context_length=server_args.context_length,
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
+            dtype=server_args.dtype,
+            quantization=server_args.quantization,
+        )
+        self.is_generation = self.model_config.is_generation
 
-        # Print batch size and memory pool info to check whether there are de-sync issues.
-        logger.error(
-            f"{self.cur_batch.batch_size()=}, "
-            f"{self.cur_batch.reqs=}, "
-            f"{self.token_to_kv_pool.available_size()=}, "
-            f"{self.tree_cache.evictable_size()=}, "
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
+        else:
+            if self.model_config.is_multimodal:
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                )
+
+    def init_memory_pool_and_cache(self):
+        server_args = self.server_args
+
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            self.tp_worker.get_memory_pool()
+        )
+
+        if (
+            server_args.chunked_prefill_size is not None
+            and server_args.disable_radix_cache
+        ):
+            self.tree_cache = ChunkCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            )
+        else:
+            if self.enable_hierarchical_cache:
+                self.tree_cache = HiRadixCache(
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                )
+            else:
+                self.tree_cache = RadixCache(
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                    disable=server_args.disable_radix_cache,
+                )
+
+        self.decode_mem_cache_buf_multiplier = (
+            1
+            if self.spec_algorithm.is_none()
+            else (
+                server_args.speculative_num_draft_tokens
+                + (
+                    server_args.speculative_eagle_topk
+                    * server_args.speculative_num_steps
+                )
+            )
         )
-        # Wait for some time so that the parent process can print the error.
-        pyspy_dump_schedulers()
-        print(file=sys.stderr, flush=True)
-        print(file=sys.stdout, flush=True)
-        time.sleep(5)
-        self.parent_process.send_signal(signal.SIGQUIT)
+
+    def init_metrics(self):
+        # The largest prefill length of a single request
+        self._largest_prefill_len: int = 0
+        # The largest context length (prefill + generation) of a single request
+        self._largest_prefill_decode_len: int = 0
+        self.last_gen_throughput: float = 0.0
+        self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
+        self.spec_num_total_accepted_tokens = 0
+        self.spec_num_total_forward_ct = 0
+        self.cum_spec_accept_length = 0
+        self.cum_spec_accept_count = 0
+        self.stats = SchedulerStats()
+        if self.enable_metrics:
+            engine_type = "unified"
+            self.metrics_collector = SchedulerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    "engine_type": engine_type,
+                },
+            )
 
     @torch.no_grad()
     def event_loop_normal(self):
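Note that the relocated `decode_mem_cache_buf_multiplier` also changes its formula: the second term now multiplies `speculative_eagle_topk` by `speculative_num_steps` rather than by `speculative_num_draft_tokens`. A small standalone sketch of the arithmetic with illustrative values (the helper name below is ours, not the package's):

```python
def decode_buf_multiplier(
    spec_decoding_enabled: bool,
    num_draft_tokens: int,
    eagle_topk: int,
    num_steps: int,
) -> int:
    """Token slots reserved per request per decode round, mirroring the new expression."""
    if not spec_decoding_enabled:
        return 1
    return num_draft_tokens + eagle_topk * num_steps


# Example: EAGLE with 4 draft tokens, top-k 4, and 3 draft steps -> 4 + 4 * 3 = 16 slots,
# versus a single slot per round when speculative decoding is off.
print(decode_buf_multiplier(True, 4, 4, 3))   # 16
print(decode_buf_multiplier(False, 4, 4, 3))  # 1
```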
@@ -932,7 +910,7 @@ class Scheduler:
         ):
             # During idle time, also collect metrics every 30 seconds.
             num_used = self.max_total_num_tokens - (
-                self.token_to_kv_pool.available_size()
+                self.token_to_kv_pool_allocator.available_size()
                 + self.tree_cache.evictable_size()
             )
             num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
@@ -1077,8 +1055,6 @@ class Scheduler:
                 else:
                     self.batch_is_full = True
                 break
-            if self.prefill_only_one_req:
-                break
 
         # Update waiting queue
         can_run_list: List[Req] = adder.can_run_list
@@ -1180,6 +1156,7 @@ class Scheduler:
         ):
             self.stop_profile()
 
+        # Run forward
         if self.is_generation:
             if self.spec_algorithm.is_none():
                 model_worker_batch = batch.get_model_worker_batch()
@@ -1200,6 +1177,7 @@ class Scheduler:
                 self.spec_num_total_forward_ct += batch.batch_size()
                 self.num_generated_tokens += num_accepted_tokens
             batch.output_ids = next_token_ids
+
             # These 2 values are needed for processing the output, but the values can be
             # modified by overlap schedule. So we have to copy them here so that
             # we can use the correct values in output processing.
@@ -1233,7 +1211,6 @@ class Scheduler:
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
     ):
         if batch.forward_mode.is_decode():
-            assert isinstance(result, GenerationBatchResult)
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
                 self.running_batch = None
@@ -1485,6 +1462,7 @@ class Scheduler:
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
+
         self.stream_output(batch.reqs, batch.return_logprob)
 
         self.token_to_kv_pool_allocator.free_group_end()
@@ -1588,7 +1566,9 @@ class Scheduler:
                 req.temp_input_token_ids_logprobs_idx
             )
             for val, idx in zip(
-                req.temp_input_top_logprobs_val, req.temp_input_top_logprobs_idx
+                req.temp_input_top_logprobs_val,
+                req.temp_input_top_logprobs_idx,
+                strict=True,
             ):
                 req.input_top_logprobs_val.extend(val)
                 req.input_top_logprobs_idx.extend(idx)
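`strict=True` (available since Python 3.10) makes `zip` raise instead of silently truncating if the temporary logprob value and index lists ever diverge in length. A quick illustration of the behavior being relied on:

```python
vals = [[-0.11], [-0.25]]
idxs = [[301], [177], [942]]  # one entry too many

try:
    for v, i in zip(vals, idxs, strict=True):
        print(v, i)
except ValueError as err:
    # Without strict=True the extra entry would be dropped silently.
    print(err)  # zip() argument 2 is longer than argument 1
```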
@@ -1813,14 +1793,18 @@ class Scheduler:
         else:  # embedding or reward model
             embeddings = []
             prompt_tokens = []
+            cached_tokens = []
             for req in reqs:
                 if req.finished():
                     rids.append(req.rid)
                     finished_reasons.append(req.finished_reason.to_json())
                     embeddings.append(req.embedding)
                     prompt_tokens.append(len(req.origin_input_ids))
+                    cached_tokens.append(req.cached_tokens)
             self.send_to_detokenizer.send_pyobj(
-                BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
+                BatchEmbeddingOut(
+                    rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
+                )
             )
 
     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
  def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
@@ -1906,6 +1890,37 @@ class Scheduler:
1906
1890
  self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
1907
1891
  self.grammar_queue = self.grammar_queue[num_ready_reqs:]
1908
1892
 
1893
+ def watchdog_thread(self):
1894
+ """A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
1895
+ self.watchdog_last_forward_ct = 0
1896
+ self.watchdog_last_time = time.time()
1897
+
1898
+ while True:
1899
+ current = time.time()
1900
+ if self.cur_batch is not None:
1901
+ if self.watchdog_last_forward_ct == self.forward_ct:
1902
+ if current > self.watchdog_last_time + self.watchdog_timeout:
1903
+ logger.error(f"Watchdog timeout ({self.watchdog_timeout=})")
1904
+ break
1905
+ else:
1906
+ self.watchdog_last_forward_ct = self.forward_ct
1907
+ self.watchdog_last_time = current
1908
+ time.sleep(self.watchdog_timeout // 2)
1909
+
1910
+ # Print batch size and memory pool info to check whether there are de-sync issues.
1911
+ logger.error(
1912
+ f"{self.cur_batch.batch_size()=}, "
1913
+ f"{self.cur_batch.reqs=}, "
1914
+ f"{self.token_to_kv_pool_allocator.available_size()=}, "
1915
+ f"{self.tree_cache.evictable_size()=}, "
1916
+ )
1917
+ # Wait for some time so that the parent process can print the error.
1918
+ pyspy_dump_schedulers()
1919
+ print(file=sys.stderr, flush=True)
1920
+ print(file=sys.stdout, flush=True)
1921
+ time.sleep(5)
1922
+ self.parent_process.send_signal(signal.SIGQUIT)
1923
+
1909
1924
  def flush_cache_wrapped(self, recv_req: FlushCacheReq):
1910
1925
  self.flush_cache()
1911
1926
 
@@ -1917,7 +1932,6 @@ class Scheduler:
             self.cur_batch = None
             self.last_batch = None
             self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
             if self.grammar_backend:
                 self.grammar_backend.reset()
             self.req_to_token_pool.clear()
@@ -2009,6 +2023,9 @@ class Scheduler:
                 req.to_abort = True
                 break
 
+    def _pause_engine(self) -> Tuple[List[Req], int]:
+        raise NotImplementedError()
+
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         """In-place update of the weights from disk."""
         success, message = self.tp_worker.update_weights_from_disk(recv_req)
@@ -1068,6 +1068,7 @@ class TokenizerManager:
                 self.metrics_collector.observe_one_finished_request(
                     recv_obj.prompt_tokens[i],
                     completion_tokens,
+                    recv_obj.cached_tokens[i],
                     state.finished_time - state.created_time,
                 )
 
@@ -20,9 +20,8 @@ Memory pool.
 
 SGLang has two levels of memory pool.
 ReqToTokenPool maps a a request to its token locations.
-TokenToKVPoolAllocator maps a token location to its KV cache data.
-KVCache actually holds the physical kv cache. Allocation indices are allocated
-by TokenToKVPoolAllocator
+TokenToKVPoolAllocator manages the indices to kv cache data.
+KVCache actually holds the physical kv cache.
 """
 
 import abc
@@ -92,14 +91,40 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))
 
 
+class KVCache(abc.ABC):
+
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+
 class TokenToKVPoolAllocator:
-    """A memory pool that maps a token location to its kv cache data."""
+    """An allocator managing the indices to kv cache data."""
 
     def __init__(
         self,
         size: int,
         dtype: torch.dtype,
         device: str,
+        kvcache: KVCache,
     ):
         self.size = size
         self.dtype = dtype
@@ -110,9 +135,14 @@ class TokenToKVPoolAllocator:
         self.free_group = []
         self.clear()
 
+        self._kvcache = kvcache
+
     def available_size(self):
         return len(self.free_slots)
 
+    def get_kvcache(self):
+        return self._kvcache
+
     def alloc(self, need_size: int):
         if need_size > len(self.free_slots):
             return None
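After this change the allocator hands out indices and also carries a reference to the physical `KVCache`, which callers retrieve via `get_kvcache()`. A rough sketch of the new relationship using stand-in classes (the real constructors take more arguments):

```python
from typing import List, Optional

import torch


class ToyKVCache:
    """Stand-in for a concrete KVCache such as MHATokenToKVPool."""

    def __init__(self, size: int, head_dim: int):
        self.k_buffer = torch.zeros(size, head_dim)
        self.v_buffer = torch.zeros(size, head_dim)

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer, self.v_buffer


class ToyTokenToKVPoolAllocator:
    """Stand-in allocator: owns the free indices plus a reference to the KV cache."""

    def __init__(self, size: int, kvcache: ToyKVCache):
        self.free_slots: List[int] = list(range(size))
        self._kvcache = kvcache

    def available_size(self) -> int:
        return len(self.free_slots)

    def get_kvcache(self) -> ToyKVCache:
        return self._kvcache

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
        return out


cache = ToyKVCache(size=16, head_dim=8)
allocator = ToyTokenToKVPoolAllocator(size=16, kvcache=cache)
loc = allocator.alloc(4)                        # indices into the KV cache
k, v = allocator.get_kvcache().get_kv_buffer(layer_id=0)
print(loc, allocator.available_size(), k.shape)
```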
@@ -147,31 +177,6 @@ class TokenToKVPoolAllocator:
         self.free_group = []
 
 
-class KVCache(abc.ABC):
-
-    @abc.abstractmethod
-    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def set_kv_buffer(
-        self,
-        layer: RadixAttention,
-        loc: torch.Tensor,
-        cache_k: torch.Tensor,
-        cache_v: torch.Tensor,
-    ) -> None:
-        raise NotImplementedError()
-
-
 class MHATokenToKVPool(KVCache):
 
     def __init__(
@@ -121,6 +121,12 @@ class TokenizerMetricsCollector:
             labelnames=labels.keys(),
         )
 
+        self.cached_tokens_total = Counter(
+            name="sglang:cached_tokens_total",
+            documentation="Number of cached prompt tokens.",
+            labelnames=labels.keys(),
+        )
+
         self.num_requests_total = Counter(
             name="sglang:num_requests_total",
             documentation="Number of requests processed.",
@@ -245,10 +251,12 @@ class TokenizerMetricsCollector:
         self,
         prompt_tokens: int,
         generation_tokens: int,
+        cached_tokens: int,
         e2e_latency: float,
     ):
         self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens)
         self.generation_tokens_total.labels(**self.labels).inc(generation_tokens)
+        self.cached_tokens_total.labels(**self.labels).inc(cached_tokens)
         self.num_requests_total.labels(**self.labels).inc(1)
         self._log_histogram(self.histogram_e2e_request_latency, e2e_latency)
         if generation_tokens >= 1:
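With the new `sglang:cached_tokens_total` counter exported next to `sglang:prompt_tokens_total`, a prefix-cache hit rate can be derived outside the server. A tiny illustrative computation (not part of the package):

```python
def prefix_cache_hit_rate(cached_tokens_total: float, prompt_tokens_total: float) -> float:
    """Fraction of prompt tokens that were served from cache."""
    if prompt_tokens_total <= 0:
        return 0.0
    return cached_tokens_total / prompt_tokens_total


# e.g. 120k of 400k prompt tokens were cache hits -> 30% hit rate
print(f"{prefix_cache_hit_rate(120_000, 400_000):.0%}")
```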
@@ -396,16 +396,10 @@ class CudaGraphRunner:
 
             run_once()
 
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         global global_graph_memory_pool
         with torch.cuda.graph(graph, pool=global_graph_memory_pool, stream=stream):
             out = run_once()
 
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         global_graph_memory_pool = graph.pool()
         return graph, out
 
@@ -427,7 +421,7 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def replay(self, forward_batch: ForwardBatch):
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
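The new keyword defaults to `False`, so existing `replay(forward_batch)` call sites keep their behavior; the name suggests that callers which have already prepared the attention backend (e.g. speculative-decoding paths) can opt out of that step. A minimal compatibility sketch with a stand-in runner class:

```python
class ToyGraphRunner:
    """Stand-in for a graph runner that grew an optional flag."""

    def replay(self, forward_batch: str, skip_attn_backend_init: bool = False) -> str:
        if not skip_attn_backend_init:
            # Previous behavior: (re)initialize attention-backend metadata before replay.
            pass
        return f"replayed {forward_batch} (skip_attn_backend_init={skip_attn_backend_init})"


runner = ToyGraphRunner()
print(runner.replay("batch"))                               # old call shape still works
print(runner.replay("batch", skip_attn_backend_init=True))  # new opt-in path
```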