sglang 0.4.4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/managers/cache_controller.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +52 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +9 -1
- sglang/srt/mem_cache/memory_pool.py +4 -1
- sglang/srt/model_executor/cuda_graph_runner.py +59 -16
- sglang/srt/model_executor/forward_batch_info.py +13 -4
- sglang/srt/models/deepseek_v2.py +180 -177
- sglang/srt/models/grok.py +374 -119
- sglang/srt/openai_api/adapter.py +22 -20
- sglang/srt/server_args.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +24 -22
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_available_gpu_memory, is_hip
 
 _is_hip = is_hip()
 
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
 
@@ -236,7 +237,7 @@
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.dp_size,
+                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
@@ -276,13 +277,12 @@
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(
-                forward_batch.global_num_tokens_cpu
-            ), max(forward_batch.global_num_tokens_cpu)
+            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                total_global_tokens in self.graphs
                 if self.disable_padding
-                else max_num_tokens <= self.max_bs
+                else total_global_tokens <= self.max_bs
             )
         else:
             is_bs_supported = (
@@ -304,6 +304,9 @@
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
+            avail_mem = get_available_gpu_memory(
+                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
+            )
             # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
                 tqdm.tqdm(list(reversed(self.capture_bs)))
@@ -311,6 +314,16 @@
                 else reversed(self.capture_bs)
             )
             for bs in capture_range:
+                if get_tensor_model_parallel_rank() == 0:
+                    avail_mem = get_available_gpu_memory(
+                        self.model_runner.device,
+                        self.model_runner.gpu_id,
+                        empty_cache=False,
+                    )
+                    capture_range.set_description(
+                        f"Capturing batches ({avail_mem=:.2f} GB)"
+                    )
+
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -345,8 +358,18 @@
             mrope_positions = self.mrope_positions[:, :bs]
 
         if self.enable_dp_attention:
-            global_num_tokens = [bs] * self.dp_size
-            gathered_buffer = self.gathered_buffer[: bs * self.dp_size]
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
        else:
             global_num_tokens = None
             gathered_buffer = None
@@ -371,7 +394,7 @@
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens=global_num_tokens,
+            global_num_tokens_gpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -392,6 +415,9 @@
 
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
             return logits_output.next_token_logits, logits_output.hidden_states
 
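dp_local_start_pos and dp_local_num_tokens appear to be lazily computed per-batch offsets for DP attention; resetting them to None before each captured run keeps a stale cached value from being baked into the graph. A minimal sketch of the lazy-field pattern being reset (field names from the diff, the surrounding class is hypothetical):

    class Batch:
        def __init__(self, tokens):
            self.tokens = tokens
            self.local_start = None  # lazily derived, cached after first use

        def get_local_start(self, rank):
            if self.local_start is None:           # computed once, then reused
                self.local_start = sum(self.tokens[:rank])
            return self.local_start

    b = Batch([3, 2, 4])
    print(b.get_local_start(2))  # 5
    b.local_start = None         # reset before reuse, as run_once() does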
@@ -426,7 +452,7 @@
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()
 
-    def replay(self, forward_batch: ForwardBatch):
+    def replay_prepare(self, forward_batch: ForwardBatch):
         self.recapture_if_needed(forward_batch)
 
         raw_bs = forward_batch.batch_size
@@ -435,7 +461,7 @@
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
+                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -459,6 +485,8 @@
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
 
         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
@@ -475,14 +503,29 @@
             seq_lens_cpu=self.seq_lens_cpu,
         )
 
+        # Store fields
+        self.raw_bs = raw_bs
+        self.raw_num_token = raw_num_token
+        self.bs = bs
+
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
         # Replay
-        self.graphs[bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[bs]
+        self.graphs[self.bs].replay()
+        next_token_logits, hidden_states = self.output_buffers[self.bs]
 
         logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[:raw_num_token],
+            next_token_logits=next_token_logits[: self.raw_num_token],
             hidden_states=(
-                hidden_states[:raw_num_token] if hidden_states is not None else None
+                hidden_states[: self.raw_num_token]
+                if hidden_states is not None
+                else None
             ),
         )
         return logits_output
sglang/srt/model_executor/forward_batch_info.py

@@ -38,7 +38,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -263,15 +263,24 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )
 
+        # For DP attention
         if batch.global_num_tokens is not None:
             ret.global_num_tokens_cpu = batch.global_num_tokens
-            ret.global_num_tokens = batch.global_num_tokens
+            ret.global_num_tokens_gpu = torch.tensor(
+                batch.global_num_tokens, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_gpu = torch.tensor(
+                batch.global_num_tokens_for_logprob, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            sum_len = sum(batch.global_num_tokens)
             ret.gathered_buffer = torch.zeros(
-                (sum(batch.global_num_tokens), model_runner.model_config.hidden_size),
+                (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
                 device=device,
             )
-
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
             return ret
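Both GPU tensors are built on the host and moved with .to(device, non_blocking=True), while the CPU-side lists are kept alongside so later code (e.g. can_run's sum over global_num_tokens_cpu) never has to synchronize with the device just to read the counts. A small sketch of the pattern, with the caveat that non_blocking only overlaps the copy when the source lives in pinned memory:

    import torch

    counts = [3, 2, 4]                       # per-rank token counts (made up)
    counts_cpu = counts                      # keep the Python list for host-side math
    counts_gpu = torch.tensor(counts, dtype=torch.int64)
    if torch.cuda.is_available():
        counts_gpu = counts_gpu.to("cuda", non_blocking=True)

    total = sum(counts_cpu)                  # no device sync needed
    print(total, counts_gpu.dtype)           # -> 9 torch.int64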