sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +49 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +2 -8
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +27 -4
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -4
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +10 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/topk.py +5 -13
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/modelopt_quant.py +8 -4
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +53 -6
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/schedule_batch.py +13 -3
- sglang/srt/managers/scheduler.py +13 -25
- sglang/srt/managers/tokenizer_manager.py +28 -25
- sglang/srt/managers/tp_worker.py +2 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +30 -16
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +41 -23
- sglang/srt/models/deepseek_v2.py +1 -2
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +0 -4
- sglang/srt/models/qwen3_moe.py +1 -6
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +76 -55
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +17 -68
- sglang/test/test_activation.py +50 -1
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +5 -5
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +75 -72
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0

sglang/srt/mem_cache/hicache_storage.py

@@ -9,6 +9,12 @@ import torch
 logger = logging.getLogger(__name__)
 
 
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+
+
 def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
     hasher = hashlib.sha256()
 
@@ -80,13 +86,20 @@ class HiCacheFile(HiCacheStorage):
 
     def __init__(self, file_path: str = "/tmp/hicache"):
         self.file_path = file_path
-        if not os.path.exists(self.file_path):
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
 
+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.tp_suffix
+
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             # todo: fixing the target_location logic to enable in-place loading
@@ -112,6 +125,7 @@ class HiCacheFile(HiCacheStorage):
         ]
 
     def set(self, key: str, value: torch.Tensor) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
@@ -130,10 +144,12 @@ class HiCacheFile(HiCacheStorage):
         return True
 
     def exists(self, key: str) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         return os.path.exists(tensor_path)
 
     def delete(self, key: str) -> None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             os.remove(tensor_path)
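
Note: the HiCacheFile changes above key every stored file by tensor-parallel rank and world size, so TP workers sharing /tmp/hicache no longer overwrite each other's tensors. A minimal sketch of the resulting naming scheme, with rank and world size passed in explicitly instead of read from the distributed group (suffixed_path is a hypothetical helper for illustration, not part of sglang):

import os

def suffixed_path(file_path: str, key: str, tp_rank: int, tp_size: int) -> str:
    # Mirrors HiCacheFile._get_suffixed_key: single-rank setups keep the bare key.
    tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
    return os.path.join(file_path, f"{key}{tp_suffix}.bin")

print(suffixed_path("/tmp/hicache", "abc123", 0, 2))  # /tmp/hicache/abc123_0_2.bin
print(suffixed_path("/tmp/hicache", "abc123", 1, 2))  # /tmp/hicache/abc123_1_2.bin
print(suffixed_path("/tmp/hicache", "abc123", 0, 1))  # /tmp/hicache/abc123.bin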

sglang/srt/mem_cache/hiradix_cache.py

@@ -50,6 +50,7 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
 
         self.tp_group = tp_cache_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
         # todo: customizable storage prefetch threshold
         self.prefetch_threshold = 256
@@ -59,6 +60,7 @@ class HiRadixCache(RadixCache):
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
             page_size,
+            self.tp_group,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
@@ -153,7 +155,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -353,7 +355,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -372,7 +374,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -380,9 +382,15 @@ class HiRadixCache(RadixCache):
                 group=self.tp_group,
             )
         for _ in range(queue_size.item()):
-            ack_id, hash_value =
-
-
+            ack_id, hash_value, completed_tokens = (
+                self.cache_controller.ack_backup_queue.get()
+            )
+            host_node = self.ongoing_backup[ack_id]
+            if completed_tokens < len(host_node.key):
+                # backup is only partially successful, split the node
+                new_node = self._split_node(host_node.key, host_node, completed_tokens)
+                new_node.hash_value = hash_value
+            host_node.release_host()
             del self.ongoing_backup[ack_id]
 
     def check_prefetch_progress(self, req_id: str):
@@ -400,15 +408,18 @@ class HiRadixCache(RadixCache):
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
 
-        min_completed_tokens =
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        min_completed_tokens = completed_tokens
+        if self.tp_world_size > 1:
             # synchrnoize TP workers to make the same update to hiradix cache
+            completed_tokens_tensor = torch.tensor(
+                min_completed_tokens, dtype=torch.int
+            )
             torch.distributed.all_reduce(
-
+                completed_tokens_tensor,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-
+            min_completed_tokens = completed_tokens_tensor.item()
         fetched_token_ids = token_ids[:min_completed_tokens]
         written_indices = host_indices[:min_completed_tokens]
         matched_length = self._insert_helper_host(
@@ -465,16 +476,19 @@ class HiRadixCache(RadixCache):
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
     ):
-
+        # align the number of fetching tokens to the page size
+        prefetch_length = len(new_input_tokens) - (
+            len(new_input_tokens) % self.page_size
+        )
+        new_input_tokens = new_input_tokens[:prefetch_length]
+        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
             return
 
         last_host_node.protect_host()
-        host_indices = self.cache_controller.mem_pool_host.alloc(
+        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
-            self.evict_host(
-            host_indices = self.cache_controller.mem_pool_host.alloc(
-                len(new_input_tokens)
-            )
+            self.evict_host(prefetch_length)
+            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
            last_host_node.release_host()
            # no sufficient host memory to prefetch
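
Note: two patterns recur in the HiRadixCache hunks above: prefetch lengths are truncated to a whole number of pages before host allocation, and per-rank progress is reduced with MIN so every TP worker commits the same number of tokens. A small plain-Python sketch of both, without a real process group (align_to_page and min_across_ranks are illustrative names; min() stands in for torch.distributed.all_reduce with ReduceOp.MIN):

def align_to_page(num_tokens: int, page_size: int) -> int:
    # Drop the trailing partial page, as done before prefetching from storage.
    return num_tokens - (num_tokens % page_size)

def min_across_ranks(completed_per_rank):
    # Every rank keeps only the tokens that all ranks managed to fetch.
    return min(completed_per_rank)

assert align_to_page(300, 64) == 256
assert min_across_ranks([300, 256, 288]) == 256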

sglang/srt/mem_cache/memory_pool_host.py

@@ -126,6 +126,9 @@ class HostKVCache(abc.ABC):
 
     @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
+        assert (
+            need_size % self.page_size == 0
+        ), "The requested size should be a multiple of the page size."
         if need_size > self.available_size():
             return None
 
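
Note: the new assertion in HostKVCache.alloc turns page alignment into a caller-side contract; the prefetch path above satisfies it by truncating requests to a page multiple first. A minimal illustration with a hard-coded page size (host_alloc is a hypothetical stand-in, not the sglang API):

PAGE_SIZE = 64

def host_alloc(need_size: int, available_size: int):
    # Same precondition as HostKVCache.alloc: unaligned requests are a bug, not a soft failure.
    assert need_size % PAGE_SIZE == 0, "The requested size should be a multiple of the page size."
    return None if need_size > available_size else need_size

assert host_alloc(300 - 300 % PAGE_SIZE, 1024) == 256  # aligned request succeeds
assert host_alloc(512, 256) is None                    # aligned but larger than available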

sglang/srt/model_executor/cuda_graph_runner.py

@@ -29,9 +29,9 @@ from torch.profiler import ProfilerActivity, profile
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
+from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
@@ -167,8 +167,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     # is very small. We add more values here to make sure we capture the maximum bs.
     capture_bs += [model_runner.req_to_token_pool.size]
 
+    mul_base = 1
+
     if server_args.enable_two_batch_overlap:
-
+        mul_base *= 2
+
+    if require_gathered_buffer(server_args):
+        mul_base *= get_attention_tp_size()
+
+    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
 
     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
@@ -306,20 +313,37 @@ class CudaGraphRunner:
             self.encoder_lens = None
 
         if self.require_gathered_buffer:
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_num_token,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
             if self.require_mlp_tp_gather:
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token * self.dp_size,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_num_token,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+        else:
+            self.global_num_tokens_gpu = None
+            self.global_num_tokens_for_logprob_gpu = None
+            self.gathered_buffer = None
 
         self.custom_mask = torch.ones(
             (
@@ -342,9 +366,9 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
-
+                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else
+                else max(forward_batch.global_num_tokens_cpu)
             )
         else:
             cuda_graph_bs = forward_batch.batch_size
@@ -480,16 +504,19 @@ class CudaGraphRunner:
         if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
-                    [
-                        num_tokens // self.dp_size + (i < (num_tokens % self.dp_size))
-                        for i in range(self.dp_size)
-                    ],
+                    [num_tokens] * self.dp_size,
                     dtype=torch.int32,
                     device=input_ids.device,
                 )
             )
-
-
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens] * self.dp_size,
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            gathered_buffer = self.gathered_buffer[: num_tokens * self.dp_size]
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -498,10 +525,15 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
             gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
-            global_num_tokens = None
             gathered_buffer = None
 
         spec_info = self.get_spec_info(num_tokens)
@@ -531,7 +563,9 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens_gpu=
+            global_num_tokens_gpu=self.global_num_tokens_gpu,
+            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.get_default_mode_in_cuda_graph(),
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -635,12 +669,13 @@ class CudaGraphRunner:
 
         # Pad
         if self.require_mlp_tp_gather:
-
-
+            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
+            max_batch_size = (
+                max_num_tokens / self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
-                else
+                else max_num_tokens
             )
-            index = bisect.bisect_left(self.capture_bs,
+            index = bisect.bisect_left(self.capture_bs, max_batch_size)
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
@@ -670,7 +705,8 @@ class CudaGraphRunner:
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
         if self.require_gathered_buffer:
-            self.global_num_tokens_gpu.
+            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
+            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
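
Note: get_batch_sizes_to_capture now keeps only batch sizes divisible by a combined base: 2 when two-batch overlap is enabled, multiplied by the attention TP size when a gathered buffer is required. A short sketch of that filter with made-up inputs (filter_capture_bs is an illustrative name, not the sglang function):

def filter_capture_bs(capture_bs, enable_two_batch_overlap, needs_gathered_buffer, attn_tp_size):
    mul_base = 1
    if enable_two_batch_overlap:
        mul_base *= 2  # each captured graph must split evenly into two micro-batches
    if needs_gathered_buffer:
        mul_base *= attn_tp_size  # tokens must shard evenly across attention TP ranks
    return [bs for bs in capture_bs if bs % mul_base == 0]

print(filter_capture_bs([1, 2, 4, 8, 12, 16, 24], True, True, 2))  # [4, 8, 12, 16, 24]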

sglang/srt/model_executor/forward_batch_info.py

@@ -38,6 +38,11 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
+    get_attention_dp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import (
     flatten_nested_list,
@@ -48,6 +53,7 @@ from sglang.srt.utils import (
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+    from sglang.srt.layers.logits_processor import LogitsProcessorOutput
     from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -68,8 +74,6 @@ class ForwardMode(IntEnum):
     MIXED = auto()
     # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence are allocated.
     IDLE = auto()
-    # Split Prefill for PD multiplexing
-    SPLIT_PREFILL = auto()
 
     # Used in speculative decoding: verify a batch in the target model.
     TARGET_VERIFY = auto()
@@ -80,6 +84,9 @@ class ForwardMode(IntEnum):
     # It is now used for triggering the sampling_info_done event for the first prefill batch.
     DUMMY_FIRST = auto()
 
+    # Split Prefill for PD multiplexing
+    SPLIT_PREFILL = auto()
+
     def is_prefill(self):
         return self.is_extend()
 
@@ -97,12 +104,12 @@ class ForwardMode(IntEnum):
     def is_mixed(self):
         return self == ForwardMode.MIXED
 
-    def is_split_prefill(self):
-        return self == ForwardMode.SPLIT_PREFILL
-
     def is_idle(self):
         return self == ForwardMode.IDLE
 
+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
     def is_target_verify(self):
         return self == ForwardMode.TARGET_VERIFY
 
@@ -126,8 +133,8 @@ class ForwardMode(IntEnum):
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
-    def
-        return self == ForwardMode.
+    def is_split_prefill(self):
+        return self == ForwardMode.SPLIT_PREFILL
 
 
 @total_ordering
@@ -242,7 +249,7 @@ class ForwardBatch:
     lora_paths: Optional[List[str]] = None
 
     # For input embeddings
-    input_embeds: Optional[torch.
+    input_embeds: Optional[torch.Tensor] = None
 
     # For cross-encoder model
     token_type_ids: Optional[torch.Tensor] = None
@@ -261,6 +268,8 @@ class ForwardBatch:
     # Has to be None when cuda graph is captured.
    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # The padding mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None
    # for extend, local start pos and num tokens is different in logits processor
    # this will be computed in get_dp_local_info
    # this will be recomputed in LogitsMetadata.from_forward_batch
@@ -286,7 +295,7 @@ class ForwardBatch:
     # For two-batch overlap
     tbo_split_seq_index: Optional[int] = None
     tbo_parent_token_range: Optional[Tuple[int, int]] = None
-    tbo_children: Optional[List[
+    tbo_children: Optional[List[ForwardBatch]] = None
 
     @classmethod
     def init_new(
@@ -340,20 +349,38 @@ class ForwardBatch:
             len(batch.input_ids), dtype=torch.int32
         ).to(device, non_blocking=True)
 
-        # For
+        # For MLP sync
         if batch.global_num_tokens is not None:
-
-
-
-            if batch.spec_num_draft_tokens is not None
-            else 1
+            from sglang.srt.speculative.eagle_utils import (
+                EagleDraftInput,
+                EagleVerifyInput,
             )
-
-
-
-
-
-
+
+            assert batch.global_num_tokens_for_logprob is not None
+            # process global_num_tokens and global_num_tokens_for_logprob
+            if batch.spec_info is not None:
+                if isinstance(batch.spec_info, EagleDraftInput):
+                    global_num_tokens = [
+                        x * batch.spec_info.num_tokens_per_batch
+                        for x in batch.global_num_tokens
+                    ]
+                    global_num_tokens_for_logprob = [
+                        x * batch.spec_info.num_tokens_for_logprob_per_batch
+                        for x in batch.global_num_tokens_for_logprob
+                    ]
+                else:
+                    assert isinstance(batch.spec_info, EagleVerifyInput)
+                    global_num_tokens = [
+                        x * batch.spec_info.draft_token_num
+                        for x in batch.global_num_tokens
+                    ]
+                    global_num_tokens_for_logprob = [
+                        x * batch.spec_info.draft_token_num
+                        for x in batch.global_num_tokens_for_logprob
+                    ]
+            else:
+                global_num_tokens = batch.global_num_tokens
+                global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
 
             ret.global_num_tokens_cpu = global_num_tokens
             ret.global_num_tokens_gpu = torch.tensor(
@@ -365,15 +392,8 @@ class ForwardBatch:
                 global_num_tokens_for_logprob, dtype=torch.int64
             ).to(device, non_blocking=True)
 
-            sum_len = sum(global_num_tokens)
-            ret.gathered_buffer = torch.zeros(
-                (sum_len, model_runner.model_config.hidden_size),
-                dtype=model_runner.dtype,
-                device=device,
-            )
-
         if ret.forward_mode.is_idle():
-            ret.positions = torch.empty((0,), device=device)
+            ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
             TboForwardBatchPreparer.prepare(
                 ret, is_draft_worker=model_runner.is_draft_worker
             )
@@ -573,6 +593,158 @@ class ForwardBatch:
             )
             self.prefix_chunk_kv_indices.append(chunk_kv_indices)
 
+    def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
+        if value == 0:
+            return torch.cat(
+                [tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])],
+                dim=0,
+            )
+        else:
+            return torch.cat(
+                [
+                    tensor,
+                    tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value),
+                ],
+                dim=0,
+            )
+
+    def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
+
+        from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+        assert self.global_num_tokens_cpu is not None
+        assert self.global_num_tokens_for_logprob_cpu is not None
+
+        global_num_tokens = self.global_num_tokens_cpu
+        sync_group_size = len(global_num_tokens)
+        attn_tp_size = get_attention_tp_size()
+
+        for i in range(sync_group_size):
+            # make sure that the padded length is divisible by attn_tp_size because we may need reduce-scatter across attn_tp dim.
+            # there is no reduce-scatter in LM logprob, so we do not need to adjust the padded length for logprob
+            global_num_tokens[i] = (
+                (global_num_tokens[i] - 1) // attn_tp_size + 1
+            ) * attn_tp_size
+
+        dp_padding_mode = DPPaddingMode.get_dp_padding_mode(global_num_tokens)
+        self.dp_padding_mode = dp_padding_mode
+
+        if dp_padding_mode.is_max_len():
+            # when DP gather mode is all gather, we will use all_gather_into_tensor to gather hidden states,
+            # where transferred tokens should be padded to the same length.
+            max_num_tokens = max(global_num_tokens)
+            global_num_tokens = [max_num_tokens] * sync_group_size
+            buffer_len = max_num_tokens * sync_group_size
+        else:
+            buffer_len = sum(global_num_tokens)
+
+        self.gathered_buffer = torch.zeros(
+            (buffer_len, model_runner.model_config.hidden_size),
+            dtype=model_runner.dtype,
+            device=model_runner.device,
+        )
+
+        bs = self.batch_size
+        if len(global_num_tokens) > 1:
+            num_tokens = global_num_tokens[get_attention_dp_rank()]
+        else:
+            num_tokens = global_num_tokens[0]
+
+        # padding
+        self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
+        self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
+
+        seq_len_fill_value = (
+            model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
+        )
+        self.seq_lens = self._pad_tensor_to_size(
+            self.seq_lens, bs, value=seq_len_fill_value
+        )
+        if self.seq_lens_cpu is not None:
+            self.seq_lens_cpu = self._pad_tensor_to_size(
+                self.seq_lens_cpu, bs, value=seq_len_fill_value
+            )
+
+        self.out_cache_loc = self._pad_tensor_to_size(self.out_cache_loc, num_tokens)
+        if self.encoder_lens is not None:
+            self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
+        self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
+        self.global_num_tokens_cpu = global_num_tokens
+        self.global_num_tokens_gpu = self.global_num_tokens_gpu.new_tensor(
+            global_num_tokens
+        )
+
+        if self.mrope_positions is not None:
+            self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
+
+        if self.extend_seq_lens is not None:
+            self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
+
+        if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
+            spec_info = self.spec_info
+            self.output_cache_loc_backup = self.out_cache_loc
+            self.hidden_states_backup = spec_info.hidden_states
+            if spec_info.topk_p is not None:
+                spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs)
+            if spec_info.topk_index is not None:
+                spec_info.topk_index = self._pad_tensor_to_size(
+                    spec_info.topk_index, bs
+                )
+            if spec_info.accept_length is not None:
+                spec_info.accept_length = self._pad_tensor_to_size(
+                    spec_info.accept_length, bs
+                )
+            spec_info.hidden_states = self._pad_tensor_to_size(
+                spec_info.hidden_states, num_tokens
+            )
+
+    def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
+
+        bs = self.batch_size
+
+        if self.spec_info is not None:
+            if self.forward_mode.is_decode():  # draft
+                num_tokens = self.hidden_states_backup.shape[0]
+                self.positions = self.positions[:num_tokens]
+                self.seq_lens = self.seq_lens[:bs]
+                self.req_pool_indices = self.req_pool_indices[:bs]
+                if self.seq_lens_cpu is not None:
+                    self.seq_lens_cpu = self.seq_lens_cpu[:bs]
+                logits_output.next_token_logits = logits_output.next_token_logits[
+                    :num_tokens
+                ]
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+            elif self.forward_mode.is_target_verify():  # verify
+                num_tokens = bs * self.spec_info.draft_token_num
+                logits_output.next_token_logits = logits_output.next_token_logits[
+                    :num_tokens
+                ]
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+            elif self.forward_mode.is_draft_extend():  # draft extend
+                self.spec_info.accept_length = self.spec_info.accept_length[:bs]
+                logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+            elif self.forward_mode.is_extend() or self.forward_mode.is_idle():
+                logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+
+            if hasattr(self, "hidden_states_backup"):
+                self.spec_info.hidden_states = self.hidden_states_backup
+            if hasattr(self, "output_cache_loc_backup"):
+                self.out_cache_loc = self.output_cache_loc_backup
+
+        elif self.forward_mode.is_decode() or self.forward_mode.is_idle():
+            logits_output.next_token_logits = logits_output.next_token_logits[:bs]
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states[:bs]
+        elif self.forward_mode.is_extend():
+            num_tokens = self.seq_lens_sum
+            logits_output.next_token_logits = logits_output.next_token_logits[
+                :num_tokens
+            ]
+            if logits_output.hidden_states is not None:
+                logits_output.hidden_states = logits_output.hidden_states[:num_tokens]
+
     # Here we suppose the length of each chunk is equal
     # For example, if we have 4 sequences with prefix length [256, 512, 768, 1024], prefix_chunk_len = 256
     # num_prefix_chunks = cdiv(1024, 256) = 4