sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0

sglang/srt/mem_cache/hiradix_cache.py

@@ -2,11 +2,12 @@ import heapq
 import logging
 import threading
 import time
+from queue import Queue
 from typing import List, Optional
 
 import torch
 
-from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (

@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
         hicache_io_backend: str,
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
+        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
     ):
 
         if hicache_io_backend == "direct":

@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
             prefetch_threshold=self.prefetch_threshold,
         )
 
+        self.prefetch_stop_policy = hicache_storage_prefetch_policy
+        # todo: customizable storage prefetch timeout
+        self.prefetch_timeout = 3  # seconds
+        logger.info(
+            f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
+        )
+
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back

@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node,
+                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
                 del self.ongoing_prefetch[req_id]
+                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
             else:
                 # the revoked operation already got terminated
                 pass

@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
             host_node.release_host()
             del self.ongoing_backup[ack_id]
 
-    def
+    def can_terminate_prefetch(self, operation: PrefetchOperation):
+        can_terminate = True
+
+        if self.prefetch_stop_policy == "best_effort":
+            return can_terminate
+
+        completed = (
+            operation.completed_tokens == len(operation.hash_value) * self.page_size
+        )
+
+        if self.prefetch_stop_policy == "wait_complete":
+            can_terminate = completed
+        elif self.prefetch_stop_policy == "timeout":
+            can_terminate = completed or (
+                time.monotonic() - operation.start_time > self.prefetch_timeout
+            )
+        else:
+            # unknown prefetch stop policy, just return True
+            return True
+
+        if self.tp_world_size > 1:
+            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            torch.distributed.all_reduce(
+                can_terminate,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            can_terminate = bool(can_terminate.item())
+
+        return can_terminate
+
+    def check_prefetch_progress(self, req_id: str) -> bool:
         if req_id not in self.ongoing_prefetch:
             # there is no ongoing prefetch for this request or it has been revoked
-            return
+            return True
 
         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over

@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
             req_id
         ]
 
+        if not self.can_terminate_prefetch(operation):
+            return False
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
 
         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1:
+        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
             # synchrnoize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int

@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
         )
         last_host_node.release_host()
         del self.ongoing_prefetch[req_id]
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+
+        return True
 
     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)

@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
             host_indices,
             operation,
         )
+        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
 
     def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
         node.last_access_time = time.monotonic()
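The hunks above add a configurable storage prefetch stop policy ("best_effort", "wait_complete", "timeout") that decides when a host-storage prefetch may be terminated. The sketch below is a minimal, standalone illustration of that decision logic; `PrefetchOp`, `can_terminate`, `page_size`, and `timeout_s` are illustrative stand-ins rather than the real sglang API, and the tensor-parallel all-reduce from the diff is omitted.

```python
import time
from dataclasses import dataclass, field
from typing import List


@dataclass
class PrefetchOp:
    # Illustrative stand-in for sglang's PrefetchOperation.
    hash_value: List[str]          # one hash per fetched page
    completed_tokens: int = 0
    start_time: float = field(default_factory=time.monotonic)


def can_terminate(op: PrefetchOp, policy: str, page_size: int, timeout_s: float = 3.0) -> bool:
    """Single-rank view of the stop-policy decision sketched in the diff.

    - "best_effort": terminate immediately, keep whatever has arrived.
    - "wait_complete": terminate only once every requested page has landed.
    - "timeout": like wait_complete, but give up after timeout_s seconds.
    """
    if policy == "best_effort":
        return True

    completed = op.completed_tokens == len(op.hash_value) * page_size
    if policy == "wait_complete":
        return completed
    if policy == "timeout":
        return completed or (time.monotonic() - op.start_time > timeout_s)
    # Unknown policy: fall back to terminating.
    return True


if __name__ == "__main__":
    op = PrefetchOp(hash_value=["a", "b"], completed_tokens=64)
    print(can_terminate(op, "best_effort", page_size=64))    # True
    print(can_terminate(op, "wait_complete", page_size=64))  # False: 128 tokens expected
```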

sglang/srt/mem_cache/memory_pool_host.py

@@ -618,7 +618,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         elif self.layout == "page_first":
             transfer_kv_all_layer_mla_lf_pf(
                 src_layers=device_pool.data_ptrs,
-
+                dst=self.kv_buffer,
                 src_indices=device_indices,
                 dst_indices=host_indices,
                 item_size=self.token_stride_size,

sglang/srt/mem_cache/multimodal_cache.py

@@ -1,24 +1,46 @@
+import logging
+from collections import OrderedDict
 from typing import Dict
 
 import torch
 
+# Set up logging for cache behavior
+logger = logging.getLogger(__name__)
+
 
 class MultiModalCache:
-    """MultiModalCache is used to store vlm encoder results"""
+    """MultiModalCache is used to store vlm encoder results with LRU eviction"""
 
     def __init__(
         self,
         max_size: int,
     ):
         self.max_size = max_size
-        self.mm_cache:
+        self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
         self.current_size = 0
 
+    def _allocate(self, embedding_size: int) -> bool:
+        """Allocate space by evicting least recently used entries"""
+        evictions = 0
+        while self.current_size + embedding_size > self.max_size and self.mm_cache:
+            _, old_embedding = self.mm_cache.popitem(last=False)
+            evicted_size = self._get_tensor_size(old_embedding)
+            self.current_size -= evicted_size
+            evictions += evicted_size
+
+        if evictions > 0:
+            logger.debug(
+                f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
+            )
+
+        if self.current_size + embedding_size > self.max_size:
+            return False
+        return True
+
     def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
-        if mm_hash in self.mm_cache:
-            return True
         data_size = self._get_tensor_size(embedding)
-
+        # Lazy free cache if not enough space
+        if not self._allocate(data_size):
             return False
         self.mm_cache[mm_hash] = embedding
         self.current_size += data_size

@@ -28,14 +50,12 @@ class MultiModalCache:
         return mm_hash in self.mm_cache
 
     def get(self, mm_hash: int) -> torch.Tensor:
-
-
-
-
-        return
-
-        self.current_size -= self._get_tensor_size(old_embedding)
-        return True
+        """Get embedding and update LRU order"""
+        if mm_hash in self.mm_cache:
+            # Move to end (most recently used)
+            self.mm_cache.move_to_end(mm_hash)
+            return self.mm_cache[mm_hash]
+        return None
 
     def clear(self):
         self.mm_cache.clear()
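The new MultiModalCache relies on the standard OrderedDict LRU pattern: `popitem(last=False)` evicts the oldest entry and `move_to_end()` marks an entry as most recently used, all bounded by a byte budget. The sketch below is a self-contained stand-in (not the sglang class, which also has `_get_tensor_size` and other helpers) showing the same pattern plus a tiny usage example.

```python
from collections import OrderedDict

import torch


class LRUByteCache:
    """Minimal stand-in mirroring the OrderedDict-based LRU pattern in the diff."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.cache: "OrderedDict[int, torch.Tensor]" = OrderedDict()
        self.current_size = 0

    @staticmethod
    def _tensor_bytes(t: torch.Tensor) -> int:
        return t.numel() * t.element_size()

    def put(self, key: int, value: torch.Tensor) -> bool:
        size = self._tensor_bytes(value)
        # Evict from the front (least recently used) until the new item fits.
        while self.current_size + size > self.max_size and self.cache:
            _, old = self.cache.popitem(last=False)
            self.current_size -= self._tensor_bytes(old)
        if self.current_size + size > self.max_size:
            return False  # item is larger than the whole budget
        self.cache[key] = value
        self.current_size += size
        return True

    def get(self, key: int):
        if key in self.cache:
            self.cache.move_to_end(key)  # mark as most recently used
            return self.cache[key]
        return None


if __name__ == "__main__":
    cache = LRUByteCache(max_size=4 * 1024)              # 4 KiB budget
    cache.put(1, torch.zeros(512, dtype=torch.float32))  # 2 KiB
    cache.put(2, torch.zeros(512, dtype=torch.float32))  # 2 KiB
    cache.get(1)                                         # touch 1 -> key 2 becomes LRU
    cache.put(3, torch.zeros(512, dtype=torch.float32))  # evicts key 2
    print(sorted(cache.cache.keys()))                    # [1, 3]
```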

sglang/srt/mem_cache/radix_cache_cpp.py (new file)

@@ -0,0 +1,229 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, List, Set
+
+import torch
+
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+from sglang.srt.mem_cache.cpp_radix_tree.radix_tree import (
+    IOHandle,
+    RadixTreeCpp,
+    TreeNodeCpp,
+)
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
+
+logger = logging.getLogger(__name__)
+
+
+class RadixCacheCpp(BasePrefixCache):
+    def _merge_tensor(self, l: List[torch.Tensor]) -> torch.Tensor:
+        """
+        Merge a list of tensors into a single tensor.
+        Args:
+            l (List[torch.Tensor]): List of tensors to merge.
+        Returns:
+            torch.Tensor: Merged tensor.
+        """
+        if len(l) == 0:
+            return torch.empty(0, dtype=torch.int64, device=self.device)
+        elif len(l) == 1:
+            return l[0]
+        else:
+            return torch.cat(l)
+
+    def __init__(
+        self,
+        disable: bool,
+        use_hicache: bool,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: BaseTokenToKVPoolAllocator,
+        tp_cache_group: torch.distributed.ProcessGroup,
+        page_size: int,
+        hicache_ratio: float,
+        hicache_size: int,
+        hicache_write_policy: str,
+        enable_kv_cache_events: bool = False,
+        hicache_oracle: bool = False,
+        enable_write_cancel: bool = False,
+    ):
+        self.disable = disable
+        self.enable_write_cancel = enable_write_cancel
+
+        assert (
+            enable_kv_cache_events is False
+        ), "HiRadixCache does not support kv cache events yet"
+        self.kv_cache = token_to_kv_pool.get_kvcache()
+
+        # record the nodes with ongoing write through
+        self.ongoing_write_through: Set[IOHandle] = set()
+        # record the node segments with ongoing load back
+        self.ongoing_load_back: Set[IOHandle] = set()
+        # todo: dynamically adjust the threshold
+        self.write_through_threshold = (
+            1 if hicache_write_policy == "write_through" else 2
+        )
+        self.device = token_to_kv_pool.device
+        self.token_to_kv_pool = token_to_kv_pool
+        self.req_to_token_pool = req_to_token_pool
+        self.page_size = page_size
+
+        self.tp_group = tp_cache_group
+
+        if not use_hicache:
+            self.tree = RadixTreeCpp(
+                disabled=self.disable,
+                page_size=page_size,
+                host_size=None,  # no host cache, this should be removed in the future
+                write_through_threshold=self.write_through_threshold,
+            )
+            self.cache_controller = None
+            return  # early return if hicache is not used
+
+        raise NotImplementedError("Host cache is not supported yet")
+
+    def reset(self):
+        if self.cache_controller is not None:
+            # need to clear the acks before resetting the cache controller
+            raise NotImplementedError("Host cache is not supported yet")
+        self.tree.reset()
+
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
+        device_indices_vec, host_indices_length, node_gpu, node_cpu = (
+            self.tree.match_prefix(key)
+        )
+        return MatchResult(
+            device_indices=self._merge_tensor(device_indices_vec),
+            last_device_node=node_gpu,
+            last_host_node=node_cpu,
+            host_hit_length=host_indices_length,
+        )
+
+    def _insert(self, key: List[int], value: torch.Tensor) -> int:
+        """
+        Insert a key-value pair into the radix tree.
+        Args:
+            key (List[int]): The key to insert, represented as a list of integers.
+            value (torch.Tensor): The value to associate with the key.
+        Returns:
+            int: Number of device indices that were already present in the tree before the insertion.
+        """
+        ongoing_write, length = self.tree.writing_through(key, value)
+        if self.cache_controller is None:
+            assert len(ongoing_write) == 0, "Implementation error"
+            return length
+
+        raise NotImplementedError("Host cache is not supported yet")
+
+    def dec_lock_ref(self, node: TreeNodeCpp):
+        """
+        Decrement the reference count of a node to root of the radix tree.
+        Args:
+            node (TreeNodeCpp): The handle of the node to decrement the reference count for.
+        """
+        self.tree.lock_ref(node, False)  # do not increment
+
+    def inc_lock_ref(self, node: TreeNodeCpp):
+        """
+        Increment the reference count of from a node to root of the radix tree.
+        Args:
+            node (TreeNodeCpp): The handle of the node to increment the reference count for.
+        """
+        self.tree.lock_ref(node, True)
+
+    def evict(self, num_tokens: int):
+        evicted_device_indices = self.tree.evict(num_tokens)
+        for indice in evicted_device_indices:
+            self.token_to_kv_pool.free(indice)
+
+    def evictable_size(self):
+        return self.tree.evictable_size()
+
+    def protected_size(self):
+        return self.tree.protected_size()
+
+    def total_size(self):
+        return self.tree.total_size()
+
+    def cache_finished_req(self, req: Req):
+        """Cache request when it finishes."""
+        assert req.req_pool_idx is not None
+        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        overall_len = len(token_ids)  # prefill + decode
+        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]
+
+        # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
+        # it will automatically align them, but length of them should be equal
+        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
+        new_prefix_len = self._insert(token_ids, kv_indices)
+
+        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+
+        # KVCache between old & new is newly generated, but already exists in the pool
+        # we need to free this newly generated kv indices
+        if old_prefix_len < new_prefix_len:
+            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+
+        # need to free the unaligned part, since it cannot be inserted into the radix tree
+        if self.page_size != 1 and (  # unaligned tail only exists when page_size > 1
+            (unaligned_len := overall_len % self.page_size) > 0
+        ):
+            # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
+            self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
+
+        # Remove req slot release the cache lock
+        self.dec_lock_ref(req.last_node)
+        self.req_to_token_pool.free(req.req_pool_idx)
+
+    def cache_unfinished_req(self, req: Req):
+        """Cache request when it is unfinished."""
+        assert req.req_pool_idx is not None
+        token_ids = req.fill_ids
+        prefill_len = len(token_ids)  # prefill only (maybe chunked)
+        kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :prefill_len]
+
+        # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
+        # it will automatically align them, but length of them should be equal
+        old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
+        new_prefix_len = self._insert(token_ids, kv_indices)
+
+        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+
+        # TODO(dark): optimize the `insert` and `match` (e.g. merge into 1 function)
+        # The prefix indices need to updated to reuse the kv indices in the pool
+        new_indices_vec, _, new_last_node, _ = self.tree.match_prefix(token_ids)
+        new_indices = self._merge_tensor(new_indices_vec)
+        assert new_prefix_len <= len(new_indices)
+
+        # KVCache between old & new is newly generated, but already exists in the pool
+        # we need to free this newly generated kv indices and reuse the indices in the pool
+        if old_prefix_len < new_prefix_len:
+            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+            reused_indices = new_indices[old_prefix_len:new_prefix_len]
+            self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, old_prefix_len:new_prefix_len
+            ] = reused_indices
+
+        if req.last_node != new_last_node:
+            self.dec_lock_ref(req.last_node)
+            self.inc_lock_ref(new_last_node)
+
+        # NOTE: there might be unaligned tail, so we may need to append it
+        assert len(new_indices) <= prefill_len < len(new_indices) + self.page_size
+        if self.page_size != 1 and len(new_indices) < prefill_len:
+            req.prefix_indices = torch.cat(
+                [new_indices, kv_indices[len(new_indices) :]]
+            )
+        else:
+            req.prefix_indices = new_indices
+        req.last_node = new_last_node
+
+    def pretty_print(self):
+        return self.tree.debug_print()
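In `cache_finished_req` above, two index ranges are handed back to the allocator: the span between the page-aligned old prefix and the prefix the radix tree reports after insertion (those KV entries already exist in the tree), and the unaligned tail that cannot live in a paged tree. The helper below is a small, hypothetical illustration of that index arithmetic only; `finished_req_free_ranges` is not part of sglang.

```python
def finished_req_free_ranges(prefix_len: int, overall_len: int, new_prefix_len: int, page_size: int):
    """Return the index ranges a finished request would hand back to the allocator.

    prefix_len     : tokens already cached for this request (req.prefix_indices)
    overall_len    : prefill + decode tokens being inserted
    new_prefix_len : tokens the radix tree reports as already present after insertion
    page_size      : KV-cache page size
    """
    free = []
    old_prefix_len = prefix_len // page_size * page_size  # page-align the old prefix
    assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"

    # Tokens between the old and new prefix already live in the tree, so the
    # freshly written copies are redundant and can be freed.
    if old_prefix_len < new_prefix_len:
        free.append((old_prefix_len, new_prefix_len))

    # An unaligned tail cannot be inserted into a paged radix tree, so it is freed too.
    unaligned = overall_len % page_size
    if page_size != 1 and unaligned > 0:
        free.append((overall_len - unaligned, overall_len))
    return free


print(finished_req_free_ranges(prefix_len=70, overall_len=200, new_prefix_len=128, page_size=64))
# [(64, 128), (192, 200)]
```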

sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py

@@ -96,6 +96,8 @@ class Hf3fsClient:
         )
         self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
         self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
+        self.shm_r.unlink()
+        self.shm_w.unlink()
 
         self.rlock = threading.RLock()
         self.wlock = threading.RLock()

@@ -176,8 +178,6 @@ class Hf3fsClient:
         del self.iov_w
         self.shm_r.close()
         self.shm_w.close()
-        self.shm_r.unlink()
-        self.shm_w.unlink()
 
     def flush(self) -> None:
         os.fsync(self.file)
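This change moves `shm_r.unlink()`/`shm_w.unlink()` from close time to construction time. With POSIX shared memory, unlinking right after creation removes the name while keeping the mapping usable through the existing handle, so the segment is reclaimed by the OS even if the process dies before `close()` runs. A minimal standalone illustration of that pattern using the Python standard library (not sglang code; `ScratchBuffer` is hypothetical):

```python
from multiprocessing import shared_memory


class ScratchBuffer:
    """Create a shared-memory scratch buffer, then unlink its name immediately.

    The memory stays usable through this handle, but no name lingers in
    /dev/shm, so nothing leaks if the process crashes before close().
    """

    def __init__(self, size: int):
        self.shm = shared_memory.SharedMemory(create=True, size=size)
        self.shm.unlink()  # remove the name now; the mapping itself survives

    def write(self, data: bytes) -> None:
        self.shm.buf[: len(data)] = data

    def read(self, n: int) -> bytes:
        return bytes(self.shm.buf[:n])

    def close(self) -> None:
        self.shm.close()  # only close() is needed; unlink already happened


if __name__ == "__main__":
    buf = ScratchBuffer(1024)
    buf.write(b"hello")
    print(buf.read(5))  # b'hello'
    buf.close()
```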

sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp (new file)

@@ -0,0 +1,35 @@
+#include <torch/extension.h>
+
+#include <cstring>
+#include <vector>
+
+void read_shm(const torch::Tensor &shm, std::vector<torch::Tensor> dst) {
+  py::gil_scoped_release release;
+  char *src_ptr = static_cast<char *>(shm.data_ptr());
+  size_t current = 0;
+  for (size_t i = 0; i < dst.size(); ++i) {
+    auto &t = dst[i];
+    size_t t_bytes = t.numel() * t.element_size();
+    char *dst_ptr = static_cast<char *>(t.data_ptr());
+    std::memcpy(dst_ptr, src_ptr + current, t_bytes);
+    current += t_bytes;
+  }
+}
+
+void write_shm(const std::vector<torch::Tensor> src, torch::Tensor &shm) {
+  py::gil_scoped_release release;
+  char *dst_ptr = static_cast<char *>(shm.data_ptr());
+  size_t current = 0;
+  for (size_t i = 0; i < src.size(); ++i) {
+    auto &t = src[i];
+    size_t t_bytes = t.numel() * t.element_size();
+    char *src_ptr = static_cast<char *>(t.data_ptr());
+    std::memcpy(dst_ptr + current, src_ptr, t_bytes);
+    current += t_bytes;
+  }
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("read_shm", &read_shm, "Read tensors from shared memory");
+  m.def("write_shm", &write_shm, "Write tensors to shared memory");
+}
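The new extension packs a list of tensors into one flat buffer (`write_shm`) and unpacks it back in the same order (`read_shm`), releasing the GIL around the memcpy loops. Below is one possible way to build and exercise it with PyTorch's JIT C++ loader; how sglang actually builds or packages `hf3fs_utils` may differ, so treat the `load(...)` call and the source path wiring as assumptions.

```python
import torch
from torch.utils.cpp_extension import load

# Compile the extension on the fly (requires a C++ toolchain).
hf3fs_utils = load(
    name="hf3fs_utils",
    sources=["sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp"],
    verbose=True,
)

# Pack two tensors into one flat byte buffer, then read them back out.
a = torch.arange(4, dtype=torch.float32)
b = torch.arange(6, dtype=torch.int64)
total_bytes = a.numel() * a.element_size() + b.numel() * b.element_size()
shm = torch.empty(total_bytes, dtype=torch.uint8)

hf3fs_utils.write_shm([a, b], shm)         # copies a then b into shm, back to back
out_a = torch.empty_like(a)
out_b = torch.empty_like(b)
hf3fs_utils.read_shm(shm, [out_a, out_b])  # copies them back in the same order

assert torch.equal(a, out_a) and torch.equal(b, out_b)
```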

sglang/srt/model_executor/cuda_graph_runner.py

@@ -29,6 +29,9 @@ from torch.profiler import ProfilerActivity, profile
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id,
+)
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput

@@ -372,6 +375,11 @@ class CudaGraphRunner:
             dtype=torch.bool,
             device="cuda",
         )
+        self.next_token_logits_buffer = torch.zeros(
+            (self.max_num_token, self.model_runner.model_config.vocab_size),
+            dtype=torch.float,
+            device="cuda",
+        )
 
         # Capture
         try:

@@ -517,6 +525,7 @@ class CudaGraphRunner:
         else:
            encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens
 
         # pipeline parallelism

@@ -567,11 +576,11 @@ class CudaGraphRunner:
         )
 
         if self.model_runner.server_args.enable_lora:
-            # It is safe to capture CUDA graph using empty LoRA
-            # `--enable-lora` is set to True (and return immediately if the LoRA
-
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
         else:
-
+            lora_ids = None
 
         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,

@@ -579,6 +588,8 @@ class CudaGraphRunner:
             input_ids=input_ids,
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
+            next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,

@@ -597,11 +608,11 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
-
+            lora_ids=lora_ids,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)
 
-        if
+        if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
 
         # Attention backend

@@ -643,11 +654,15 @@ class CudaGraphRunner:
 
         run_once()
 
-
-
+        if get_global_graph_memory_pool() is None:
+            set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
+        # Set graph pool id globally to be able to use symmetric memory
+        set_graph_pool_id(get_global_graph_memory_pool())
+        with torch.cuda.graph(
+            graph, pool=get_global_graph_memory_pool(), stream=stream
+        ):
             out = run_once()
 
-        global_graph_memory_pool = graph.pool()
         return graph, out
 
     def recapture_if_needed(self, forward_batch: ForwardBatch):
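The last hunk changes capture to obtain a shared memory pool up front via `torch.cuda.graph_pool_handle()` and pass it to every `torch.cuda.graph(...)` capture, instead of reading `graph.pool()` after the first capture. The sketch below shows that pattern in isolation; it requires a CUDA device, and sglang's `set_graph_pool_id()` / pynccl allocator hookup is not reproduced here.

```python
# Minimal sketch: capture several CUDA graphs that share one memory pool.
import torch

assert torch.cuda.is_available()

static_x = torch.zeros(8, device="cuda")
shared_pool = torch.cuda.graph_pool_handle()  # one pool reused by every capture

graphs = []
for scale in (1.0, 2.0):
    g = torch.cuda.CUDAGraph()

    # Warm up on a side stream before capture, as recommended by PyTorch.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        _ = static_x * scale
    torch.cuda.current_stream().wait_stream(s)

    # Pass the shared pool explicitly instead of letting each graph own one.
    with torch.cuda.graph(g, pool=shared_pool):
        static_y = static_x * scale
    graphs.append((g, static_y))

static_x.fill_(3.0)
for g, static_y in graphs:
    g.replay()
    print(static_y[:3])  # tensor([3., 3., 3.]) then tensor([6., 6., 6.])
```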
|