sglang 0.4.4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_available_gpu_memory, is_hip
 
 _is_hip = is_hip()
 
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
 
@@ -236,7 +237,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.dp_size,
+                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
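Note on the hunk above: the DP-attention gather buffer is now sized per token rather than per request, so captured graphs have room for the extra tokens a speculative draft contributes. A rough capacity check under assumed sizes (the concrete numbers below are illustrative, not taken from the diff):

    # Illustrative capacity check for the DP-attention gather buffer.
    # max_bs, dp_size, num_tokens_per_bs and hidden_size are assumed values.
    max_bs = 160
    dp_size = 2
    num_tokens_per_bs = 4   # > 1 once speculative decoding packs draft tokens per request
    hidden_size = 4096

    old_rows = max_bs * dp_size                       # 320 rows: one token per request
    new_rows = max_bs * dp_size * num_tokens_per_bs   # 1280 rows: room for the draft tokens
    print(old_rows, new_rows, hidden_size)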
@@ -276,13 +277,12 @@ class CudaGraphRunner:
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(
-                forward_batch.global_num_tokens_cpu
-            ), max(forward_batch.global_num_tokens_cpu)
+            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                total_global_tokens in self.graphs
                 if self.disable_padding
-                else max_num_tokens <= self.max_bs
+                else total_global_tokens <= self.max_bs
             )
         else:
             is_bs_supported = (
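The `can_run` change above relaxes the DP-attention eligibility test: instead of requiring every rank to report the same token count and matching that count against a captured graph, it sums the per-rank counts and checks the total. A minimal sketch of the new check, assuming `graphs` is keyed by captured sizes:

    def can_run_dp(global_num_tokens_cpu, graphs, max_bs, disable_padding):
        # New rule: only the summed token count across DP ranks matters.
        total_global_tokens = sum(global_num_tokens_cpu)
        if disable_padding:
            return total_global_tokens in graphs
        return total_global_tokens <= max_bs

    # Ranks reporting [3, 5] failed the old min == max test; here the sum (8)
    # only has to be a captured size, or fit under max_bs when padding is allowed.
    print(can_run_dp([3, 5], graphs={8, 16}, max_bs=16, disable_padding=True))   # True
    print(can_run_dp([3, 5], graphs={16}, max_bs=16, disable_padding=False))     # True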
@@ -304,6 +304,9 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
+            avail_mem = get_available_gpu_memory(
+                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
+            )
             # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
                 tqdm.tqdm(list(reversed(self.capture_bs)))
@@ -311,6 +314,16 @@ class CudaGraphRunner:
                 else reversed(self.capture_bs)
             )
             for bs in capture_range:
+                if get_tensor_model_parallel_rank() == 0:
+                    avail_mem = get_available_gpu_memory(
+                        self.model_runner.device,
+                        self.model_runner.gpu_id,
+                        empty_cache=False,
+                    )
+                    capture_range.set_description(
+                        f"Capturing batches ({avail_mem=:.2f} GB)"
+                    )
+
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -345,8 +358,18 @@ class CudaGraphRunner:
         mrope_positions = self.mrope_positions[:, :bs]
 
         if self.enable_dp_attention:
-            global_num_tokens = [bs] * self.tp_size
-            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
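During capture, the per-rank token counts are now a device tensor that spreads `num_tokens` across the DP ranks, with the remainder of `bs % dp_size` going to the lowest-numbered ranks, and the gather buffer is sliced to exactly `num_tokens` rows. A worked example of that split (the values are assumptions; in the runner `num_tokens = bs * num_tokens_per_bs`):

    # Worked example of the capture-time per-rank split (values are assumptions).
    bs, num_tokens_per_bs, dp_size = 7, 1, 2
    num_tokens = bs * num_tokens_per_bs

    per_rank = [num_tokens // dp_size + (i < bs % dp_size) for i in range(dp_size)]
    print(per_rank)       # [4, 3]
    print(sum(per_rank))  # 7 == num_tokens, so gathered_buffer[:num_tokens] holds every row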
@@ -371,7 +394,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens_cpu=global_num_tokens,
+            global_num_tokens_gpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -392,6 +415,9 @@ class CudaGraphRunner:
 
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
             return logits_output.next_token_logits, logits_output.hidden_states
 
@@ -426,7 +452,7 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay_prepare(self, forward_batch: ForwardBatch):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
@@ -435,7 +461,7 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
+                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -459,6 +485,8 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
 
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
@@ -475,14 +503,29 @@ class CudaGraphRunner:
             seq_lens_cpu=self.seq_lens_cpu,
         )
 
+        # Store fields
+        self.raw_bs = raw_bs
+        self.raw_num_token = raw_num_token
+        self.bs = bs
+
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
         # Replay
-        self.graphs[bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[bs]
+        self.graphs[self.bs].replay()
+        next_token_logits, hidden_states = self.output_buffers[self.bs]
 
         logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[:raw_num_token],
+            next_token_logits=next_token_logits[: self.raw_num_token],
             hidden_states=(
-                hidden_states[:raw_num_token] if hidden_states is not None else None
+                hidden_states[: self.raw_num_token]
+                if hidden_states is not None
+                else None
             ),
         )
         return logits_output
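Splitting `replay` into `replay_prepare` plus a thin `replay` lets a caller replay the same captured graph again without redoing the padding and attention-backend setup; the padded batch size and raw token count chosen during preparation are carried over as `self.bs` and `self.raw_num_token`. A minimal, self-contained sketch of that stateful handoff (a stub for illustration, not the real runner):

    # Minimal sketch of the prepare/replay handoff (a stub, not the real runner).
    class RunnerSketch:
        def replay_prepare(self, batch):
            # Padding, buffer copies, and backend setup would happen here;
            # remember the decisions for later replays.
            self.bs = batch["padded_bs"]
            self.raw_num_token = batch["num_tokens"]

        def replay(self, batch, skip_attn_backend_init=False):
            if not skip_attn_backend_init:
                self.replay_prepare(batch)
            # else: reuse self.bs / self.raw_num_token from the previous prepare,
            # as speculative decoding does when only input_ids/positions change.
            return f"replay graph bs={self.bs}, trim to {self.raw_num_token} tokens"

    runner = RunnerSketch()
    print(runner.replay({"padded_bs": 8, "num_tokens": 7}))
    print(runner.replay({"padded_bs": 8, "num_tokens": 7}, skip_attn_backend_init=True))

The remaining hunks appear to move to the module that defines ForwardBatch (the runner imports it from sglang.srt.model_executor.forward_batch_info), where the same sum-of-tokens sizing is applied on the batch side.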
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend, next_power_of_2
+from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -263,15 +263,24 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )
 
+        # For DP attention
         if batch.global_num_tokens is not None:
             ret.global_num_tokens_cpu = batch.global_num_tokens
-            max_len = max(ret.global_num_tokens_cpu)
+            ret.global_num_tokens_gpu = torch.tensor(
+                batch.global_num_tokens, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_gpu = torch.tensor(
+                batch.global_num_tokens_for_logprob, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            sum_len = sum(batch.global_num_tokens)
             ret.gathered_buffer = torch.zeros(
-                (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
+                (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
                 device=device,
             )
-
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
             return ret
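On the `ForwardBatch` side, the gather buffer is likewise sized by the sum of tokens across ranks instead of padding every rank out to the longest one, and the per-rank counts are mirrored onto the device for the kernels that need them. A quick comparison under assumed counts (illustrative values only):

    import torch

    # Assumed per-rank token counts and model width, for illustration only.
    global_num_tokens = [3, 5, 2, 6]
    tp_size = 4
    hidden_size = 4096

    old_rows = max(global_num_tokens) * tp_size  # 24 rows, padded to the longest rank
    new_rows = sum(global_num_tokens)            # 16 rows, exactly what gets gathered

    gathered_buffer = torch.zeros((new_rows, hidden_size), dtype=torch.float16)
    global_num_tokens_gpu = torch.tensor(global_num_tokens, dtype=torch.int64)
    print(old_rows, new_rows, tuple(gathered_buffer.shape))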