sglang-0.4.0-py3-none-any.whl → sglang-0.4.0.post2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_offline_throughput.py +18 -6
- sglang/bench_one_batch.py +13 -0
- sglang/bench_serving.py +8 -1
- sglang/check_env.py +140 -48
- sglang/lang/backend/runtime_endpoint.py +1 -0
- sglang/lang/chat_template.py +32 -0
- sglang/llama3_eval.py +316 -0
- sglang/srt/constrained/outlines_backend.py +5 -0
- sglang/srt/constrained/xgrammar_backend.py +9 -6
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +22 -5
- sglang/srt/layers/attention/torch_native_backend.py +22 -8
- sglang/srt/layers/attention/triton_backend.py +38 -33
- sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +665 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
- sglang/srt/layers/fused_moe_triton/layer.py +1 -1
- sglang/srt/layers/logits_processor.py +133 -95
- sglang/srt/layers/quantization/__init__.py +2 -47
- sglang/srt/layers/quantization/fp8.py +607 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +11 -2
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/torchao_utils.py +58 -45
- sglang/srt/managers/detokenizer_manager.py +37 -17
- sglang/srt/managers/io_struct.py +39 -10
- sglang/srt/managers/schedule_batch.py +39 -24
- sglang/srt/managers/schedule_policy.py +64 -5
- sglang/srt/managers/scheduler.py +236 -197
- sglang/srt/managers/tokenizer_manager.py +99 -58
- sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +2 -2
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/mem_cache/radix_cache.py +12 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -11
- sglang/srt/model_executor/model_runner.py +24 -9
- sglang/srt/model_parallel.py +67 -10
- sglang/srt/models/commandr.py +2 -2
- sglang/srt/models/deepseek_v2.py +87 -7
- sglang/srt/models/gemma2.py +34 -0
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/granite.py +517 -0
- sglang/srt/models/grok.py +72 -13
- sglang/srt/models/llama.py +22 -5
- sglang/srt/models/llama_classification.py +11 -23
- sglang/srt/models/llama_reward.py +0 -2
- sglang/srt/models/llava.py +37 -14
- sglang/srt/models/mixtral.py +12 -9
- sglang/srt/models/phi3_small.py +0 -5
- sglang/srt/models/qwen2.py +20 -0
- sglang/srt/models/qwen2_moe.py +0 -5
- sglang/srt/models/torch_native_llama.py +0 -5
- sglang/srt/openai_api/adapter.py +4 -0
- sglang/srt/openai_api/protocol.py +9 -4
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +4 -4
- sglang/srt/server_args.py +62 -13
- sglang/srt/utils.py +57 -10
- sglang/test/test_utils.py +3 -2
- sglang/utils.py +10 -3
- sglang/version.py +1 -1
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
@@ -20,9 +20,11 @@ from contextlib import contextmanager
|
|
20
20
|
from enum import Enum, auto
|
21
21
|
from typing import Dict, List, Optional
|
22
22
|
|
23
|
+
import torch
|
24
|
+
|
23
25
|
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
24
26
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
25
|
-
from sglang.srt.mem_cache.radix_cache import TreeNode
|
27
|
+
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
26
28
|
|
27
29
|
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
|
28
30
|
# This can prevent the server from being too conservative.
|
@@ -32,6 +34,21 @@ CLIP_MAX_NEW_TOKENS_ESTIMATION = int(
|
|
32
34
|
os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
|
33
35
|
)
|
34
36
|
|
37
|
+
# Threshold for in-batch prefix cache.
|
38
|
+
# If a request has a matched prefix length (against existing cache) less than this value,
|
39
|
+
# the scheduler runs the in-batch prefix caching check for this request.
|
40
|
+
# If we set it to -1, it means we disable in-batch prefix caching.
|
41
|
+
IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int(
|
42
|
+
os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32")
|
43
|
+
)
|
44
|
+
|
45
|
+
# Threshold for in-batch prefix cache.
|
46
|
+
# If a request has a matched prefix length (within the waiting queue) larger than this value,
|
47
|
+
# the scheduler deprioritizes this request
|
48
|
+
IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
|
49
|
+
os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32")
|
50
|
+
)
|
51
|
+
|
35
52
|
|
36
53
|
class SchedulePolicy:
|
37
54
|
def __init__(self, policy: str, tree_cache: BasePrefixCache):
|
@@ -42,6 +59,11 @@ class SchedulePolicy:
|
|
42
59
|
self.policy = policy
|
43
60
|
self.tree_cache = tree_cache
|
44
61
|
|
62
|
+
# It is used to find the matching prefix for in-batch prefix caching.
|
63
|
+
self.waiting_queue_radix_tree = RadixCache(
|
64
|
+
req_to_token_pool=None, token_to_kv_pool=None, disable=False
|
65
|
+
)
|
66
|
+
|
45
67
|
def calc_priority(self, waiting_queue: List[Req]):
|
46
68
|
if len(waiting_queue) > 128 and self.policy == "lpm":
|
47
69
|
# Turn off the expensive prefix matching and sorting when the #queue is large.
|
@@ -52,17 +74,53 @@ class SchedulePolicy:
|
|
52
74
|
# Compute matched prefix length
|
53
75
|
prefix_computed = False
|
54
76
|
if policy == "lpm" or policy == "dfs-weight":
|
77
|
+
# rid to deprioritize in the current run for in-batch prefix caching.
|
78
|
+
temporary_deprioritized = set()
|
79
|
+
self.waiting_queue_radix_tree.reset()
|
80
|
+
|
55
81
|
for r in waiting_queue:
|
82
|
+
prefix_ids = r.adjust_max_prefix_ids()
|
83
|
+
|
56
84
|
# NOTE: the prefix_indices must always be aligned with last_node
|
57
85
|
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
|
58
|
-
rid=r.rid, key=
|
86
|
+
rid=r.rid, key=prefix_ids
|
59
87
|
)
|
60
88
|
|
89
|
+
# NOTE(sang): This logic is for in-batch prefix caching;
|
90
|
+
# If there are more than 1 request that have small matching prefix from
|
91
|
+
# existing cache, but all those requests share the same prefix, we prefer
|
92
|
+
# to schedule only one of them so that we can increase the cache hit rate.
|
93
|
+
# We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
|
94
|
+
# threshold means we cannot use in-batch prefix caching for short prefixes.
|
95
|
+
# It is kind of common when the engine is long running (e.g., imagine the prefix "the").
|
96
|
+
if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
|
97
|
+
in_batch_matching_prefixes, _ = (
|
98
|
+
self.waiting_queue_radix_tree.match_prefix(
|
99
|
+
rid=r.rid, key=prefix_ids
|
100
|
+
)
|
101
|
+
)
|
102
|
+
if (
|
103
|
+
len(in_batch_matching_prefixes)
|
104
|
+
>= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
|
105
|
+
):
|
106
|
+
temporary_deprioritized.add(r.rid)
|
107
|
+
else:
|
108
|
+
# Insert with a dummy key
|
109
|
+
self.waiting_queue_radix_tree.insert(
|
110
|
+
prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
|
111
|
+
)
|
112
|
+
|
61
113
|
prefix_computed = True
|
62
114
|
|
63
115
|
if policy == "lpm":
|
64
116
|
# Longest Prefix Match
|
65
|
-
waiting_queue.sort(
|
117
|
+
waiting_queue.sort(
|
118
|
+
key=lambda r: (
|
119
|
+
-len(r.prefix_indices)
|
120
|
+
if r.rid not in temporary_deprioritized
|
121
|
+
else float("inf")
|
122
|
+
)
|
123
|
+
)
|
66
124
|
elif policy == "fcfs":
|
67
125
|
# first come first serve
|
68
126
|
pass
|
@@ -72,6 +130,7 @@ class SchedulePolicy:
|
|
72
130
|
elif policy == "random":
|
73
131
|
random.shuffle(waiting_queue)
|
74
132
|
elif policy == "dfs-weight":
|
133
|
+
# Experimental policy based on custom weights
|
75
134
|
last_node_to_reqs = defaultdict(list)
|
76
135
|
for req in waiting_queue:
|
77
136
|
last_node_to_reqs[req.last_node].append(req)
|
@@ -101,8 +160,8 @@ class SchedulePolicy:
|
|
101
160
|
def get_dfs_priority(
|
102
161
|
self,
|
103
162
|
cur_node: TreeNode,
|
104
|
-
node_to_priority: Dict,
|
105
|
-
last_node_to_reqs: Dict,
|
163
|
+
node_to_priority: Dict[TreeNode, int],
|
164
|
+
last_node_to_reqs: Dict[TreeNode, List[Req]],
|
106
165
|
q: List,
|
107
166
|
):
|
108
167
|
childs = [child for child in cur_node.children.values()]
|