sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -0
- sglang/srt/layers/attention/__init__.py +14 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +211 -81
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/logits_processor.py +167 -212
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +26 -2
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +62 -26
- sglang/srt/managers/tokenizer_manager.py +22 -20
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/model_executor/cuda_graph_runner.py +118 -73
- sglang/srt/model_executor/forward_batch_info.py +33 -8
- sglang/srt/model_executor/model_runner.py +63 -61
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +97 -26
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +21 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +9 -5
- sglang/srt/server_args.py +108 -57
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +618 -0
- sglang/srt/speculative/eagle_worker.py +170 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +15 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/RECORD +63 -39
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py CHANGED

```diff
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,7 +31,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 
 import dataclasses
 import logging
-from typing import List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -47,6 +49,10 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
+if TYPE_CHECKING:
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -565,9 +571,13 @@ class ScheduleBatch:
     # Has grammar
     has_grammar: bool = False
 
-    #
+    # Device
     device: str = "cuda"
 
+    # Speculative decoding
+    spec_algorithm: SpeculativeAlgorithm = None
+    spec_info: Optional[SpecInfo] = None
+
     @classmethod
     def init_new(
         cls,
@@ -577,6 +587,7 @@ class ScheduleBatch:
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
+        spec_algorithm: SpeculativeAlgorithm,
     ):
         return cls(
             reqs=reqs,
@@ -589,6 +600,7 @@ class ScheduleBatch:
             has_stream=any(req.stream for req in reqs),
             has_grammar=any(req.grammar for req in reqs),
             device=req_to_token_pool.device,
+            spec_algorithm=spec_algorithm,
         )
 
     def batch_size(self):
@@ -998,6 +1010,8 @@ class ScheduleBatch:
 
     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
+        if self.spec_algorithm.is_eagle():
+            return
 
         self.input_ids = self.output_ids
         self.output_ids = None
@@ -1103,6 +1117,9 @@ class ScheduleBatch:
         self.has_stream |= other.has_stream
         self.has_grammar |= other.has_grammar
 
+        if self.spec_info:
+            self.spec_info.merge_batch(other.spec_info)
+
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
@@ -1144,6 +1161,8 @@ class ScheduleBatch:
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            spec_algorithm=self.spec_algorithm,
+            spec_info=self.spec_info,
         )
 
     def copy(self):
@@ -1155,6 +1174,7 @@ class ScheduleBatch:
             out_cache_loc=self.out_cache_loc,
             return_logprob=self.return_logprob,
             decoding_reqs=self.decoding_reqs,
+            spec_algorithm=self.spec_algorithm,
        )
 
     def __str__(self):
@@ -1214,6 +1234,10 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # Speculative decoding
+    spec_algorithm: SpeculativeAlgorithm = None
+    spec_info: Optional[SpecInfo] = None
+
 
 @triton.jit
 def write_req_to_token_pool_triton(
```
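The schedule_batch.py hunks above mainly thread two new fields, `spec_algorithm` and `spec_info`, from `ScheduleBatch.init_new()` through `get_model_worker_batch()` into `ModelWorkerBatch`. A minimal, self-contained sketch of that pattern (not sglang code; `SpecAlgo`, `Batch`, and `get_worker_batch` are invented names for illustration):

```python
# Illustrative sketch only: mirrors how spec_algorithm / spec_info ride along with
# the batch metadata in the diff above. SpecAlgo, Batch, and get_worker_batch are
# hypothetical names, not sglang APIs.
import dataclasses
from enum import Enum, auto
from typing import Any, List, Optional


class SpecAlgo(Enum):
    NONE = auto()
    EAGLE = auto()

    def is_none(self) -> bool:
        return self == SpecAlgo.NONE

    def is_eagle(self) -> bool:
        return self == SpecAlgo.EAGLE


@dataclasses.dataclass
class Batch:
    input_ids: List[int]
    spec_algorithm: SpecAlgo = SpecAlgo.NONE
    spec_info: Optional[Any] = None

    def get_worker_batch(self) -> dict:
        # The speculative fields are copied into the per-step worker batch,
        # just like the other scheduling metadata.
        return {
            "input_ids": list(self.input_ids),
            "spec_algorithm": self.spec_algorithm,
            "spec_info": self.spec_info,
        }


batch = Batch(input_ids=[1, 2, 3], spec_algorithm=SpecAlgo.EAGLE)
print(batch.get_worker_batch()["spec_algorithm"].is_eagle())  # True
```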
sglang/srt/managers/schedule_policy.py CHANGED

```diff
@@ -18,7 +18,7 @@ import random
 from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum, auto
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Set, Union
 
 import torch
 
@@ -50,13 +50,26 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
 )
 
 
+class CacheAwarePolicy(Enum):
+    """Scheduling policies that are aware of the tree cache."""
+
+    LPM = "lpm"  # longest prefix match
+    DFS_WEIGHT = "dfs-weight"  # depth-first search weighting
+
+
+class CacheAgnosticPolicy(Enum):
+    """Scheduling policies that are not aware of the tree cache."""
+
+    FCFS = "fcfs"  # first come first serve
+    LOF = "lof"  # longest output first
+    RANDOM = "random"
+
+
 class SchedulePolicy:
-
-        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
-            # LPM and DFS-weight is meaningless when the tree cache is disabled.
-            policy = "fcfs"
+    Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]
 
-
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+        self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache
 
         # It is used to find the matching prefix for in-batch prefix caching.
@@ -64,110 +77,166 @@ class SchedulePolicy:
             req_to_token_pool=None, token_to_kv_pool=None, disable=False
         )
 
-    def calc_priority(self, waiting_queue: List[Req]):
-
-            # Turn off the expensive prefix matching and sorting when the #queue is large.
-            policy = "fcfs"
-        else:
-            policy = self.policy
+    def calc_priority(self, waiting_queue: List[Req]) -> bool:
+        policy = self._determine_active_policy(waiting_queue)
 
-        # Compute matched prefix length
         prefix_computed = False
-        if policy
-
-            temporary_deprioritized =
-
-
-
-
-
-                # NOTE: the prefix_indices must always be aligned with last_node
-                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=prefix_ids
+        if isinstance(policy, CacheAwarePolicy):
+            prefix_computed = True
+            temporary_deprioritized = self._compute_prefix_matches(
+                waiting_queue, policy
+            )
+            if policy == CacheAwarePolicy.LPM:
+                SchedulePolicy._sort_by_longest_prefix(
+                    waiting_queue, temporary_deprioritized
                 )
+            elif policy == CacheAwarePolicy.DFS_WEIGHT:
+                SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
+            else:
+                raise ValueError(f"Unknown CacheAware Policy: {policy=}")
+        else:
+            if policy == CacheAgnosticPolicy.FCFS:
+                pass
+            elif policy == CacheAgnosticPolicy.LOF:
+                SchedulePolicy._sort_by_longest_output(waiting_queue)
+            elif policy == CacheAgnosticPolicy.RANDOM:
+                SchedulePolicy._sort_randomly(waiting_queue)
+            else:
+                raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")
 
-
-                # If there are more than 1 request that have small matching prefix from
-                # existing cache, but all those requests share the same prefix, we prefer
-                # to schedule only one of them so that we can increase the cache hit rate.
-                # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
-                # threshold means we cannot use in-batch prefix caching for short prefixes.
-                # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
-                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
-                    in_batch_matching_prefixes, _ = (
-                        self.waiting_queue_radix_tree.match_prefix(
-                            rid=r.rid, key=prefix_ids
-                        )
-                    )
-                    if (
-                        len(in_batch_matching_prefixes)
-                        >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
-                    ):
-                        temporary_deprioritized.add(r.rid)
-                    else:
-                        # Insert with a dummy key
-                        self.waiting_queue_radix_tree.insert(
-                            prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
-                        )
+        return prefix_computed
 
-
+    def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
+        if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
+            # Turn off the expensive prefix matching and sorting when the #queue is large.
+            return CacheAgnosticPolicy.FCFS
+        return self.policy
+
+    def _validate_and_adjust_policy(
+        self, policy: str, tree_cache: BasePrefixCache
+    ) -> Policy:
+        """
+        Validates the policy and adjusts it if necessary based on tree cache settings.
+        """
+        try:
+            policy_enum = CacheAwarePolicy(policy)
+            if tree_cache.disable:
+                # If tree_cache is disabled, using CacheAgnosticPolicy policy
+                return CacheAgnosticPolicy.FCFS
+            return policy_enum
+        except ValueError:
+            try:
+                return CacheAgnosticPolicy(policy)
+            except ValueError:
+                raise ValueError(f"Unknown schedule_policy: {policy=}")
+
+    def _compute_prefix_matches(
+        self, waiting_queue: List[Req], policy: CacheAwarePolicy
+    ) -> Set[int]:
+        """
+        Computes and caches the matching prefixes for requests in the waiting queue,
+        and handles in-batch prefix caching logic.
+        """
+        temporary_deprioritized: Set[int] = set()
+        self.waiting_queue_radix_tree.reset()
+
+        for r in waiting_queue:
+            prefix_ids = r.adjust_max_prefix_ids()
+
+            # NOTE: the prefix_indices must always be aligned with last_node
+            r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                rid=r.rid, key=prefix_ids
+            )
 
-
-            #
-
-
-
-
+            # NOTE(sang): This logic is for in-batch prefix caching;
+            # If there are more than 1 request that have small matching prefix from
+            # existing cache, but all those requests share the same prefix, we prefer
+            # to schedule only one of them so that we can increase the cache hit rate.
+            # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
+            # threshold means we cannot use in-batch prefix caching for short prefixes.
+            # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
+            if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
+                in_batch_matching_prefixes, _ = (
+                    self.waiting_queue_radix_tree.match_prefix(
+                        rid=r.rid, key=prefix_ids
+                    )
                 )
+                if (
+                    len(in_batch_matching_prefixes)
+                    >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
+                ):
+                    temporary_deprioritized.add(r.rid)
+                else:
+                    # Insert with a dummy key
+                    self.waiting_queue_radix_tree.insert(
+                        prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
+                    )
+        return temporary_deprioritized
+
+    @staticmethod
+    def _sort_by_longest_prefix(
+        waiting_queue: List[Req], temporary_deprioritized: Set[int]
+    ) -> None:
+        """Sorts the waiting queue based on the longest prefix match."""
+        waiting_queue.sort(
+            key=lambda r: (
+                -len(r.prefix_indices)
+                if r.rid not in temporary_deprioritized
+                else float("inf")
            )
-
-            # first come first serve
-            pass
-        elif policy == "lof":
-            # longest output first
-            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-        elif policy == "random":
-            random.shuffle(waiting_queue)
-        elif policy == "dfs-weight":
-            # Experimental policy based on custom weights
-            last_node_to_reqs = defaultdict(list)
-            for req in waiting_queue:
-                last_node_to_reqs[req.last_node].append(req)
-
-            node_to_weight = defaultdict(int)
-            for node in last_node_to_reqs:
-                node_to_weight[node] = len(last_node_to_reqs[node])
-            self.calc_weight(self.tree_cache.root_node, node_to_weight)
-
-            waiting_queue.clear()
-            self.get_dfs_priority(
-                self.tree_cache.root_node,
-                node_to_weight,
-                last_node_to_reqs,
-                waiting_queue,
-            )
-        else:
-            raise ValueError(f"Unknown schedule_policy: {policy=}")
+        )
 
-
+    @staticmethod
+    def _sort_by_dfs_weight(
+        waiting_queue: List[Req], tree_cache: BasePrefixCache
+    ) -> None:
+        """Sorts the waiting queue based on a depth-first search weighting."""
+        last_node_to_reqs = defaultdict(list)
+        for req in waiting_queue:
+            last_node_to_reqs[req.last_node].append(req)
+
+        node_to_weight = defaultdict(int)
+        for node in last_node_to_reqs:
+            node_to_weight[node] = len(last_node_to_reqs[node])
+        SchedulePolicy._calc_weight(tree_cache.root_node, node_to_weight)
+
+        waiting_queue.clear()
+        SchedulePolicy._get_dfs_priority(
+            tree_cache.root_node,
+            node_to_weight,
+            last_node_to_reqs,
+            waiting_queue,
+        )
+
+    @staticmethod
+    def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
+        """Sorts the waiting queue based on the longest output (max_new_tokens)."""
+        waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
 
-
+    @staticmethod
+    def _sort_randomly(waiting_queue: List[Req]) -> None:
+        """Shuffles the waiting queue randomly."""
+        random.shuffle(waiting_queue)
+
+    @staticmethod
+    def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
         for child in cur_node.children.values():
-
+            SchedulePolicy._calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]
 
-
-
+    @staticmethod
+    def _get_dfs_priority(
         cur_node: TreeNode,
         node_to_priority: Dict[TreeNode, int],
         last_node_to_reqs: Dict[TreeNode, List[Req]],
         q: List,
-    ):
+    ) -> None:
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
-
+            SchedulePolicy._get_dfs_priority(
+                child, node_to_priority, last_node_to_reqs, q
+            )
         q.extend(last_node_to_reqs[cur_node])
 
 
```
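The schedule_policy.py refactor above replaces string comparisons with a CacheAwarePolicy/CacheAgnosticPolicy split and a validation step that falls back to FCFS when the prefix cache is disabled. Below is a standalone sketch of that resolution pattern, assuming the enum values mirror the CLI strings; the names `resolve_policy`, `CacheAware`, and `CacheAgnostic` are illustrative, not sglang's.

```python
# Standalone sketch of the policy-resolution pattern; not the sglang implementation.
from enum import Enum
from typing import Union


class CacheAware(Enum):
    LPM = "lpm"
    DFS_WEIGHT = "dfs-weight"


class CacheAgnostic(Enum):
    FCFS = "fcfs"
    LOF = "lof"
    RANDOM = "random"


Policy = Union[CacheAware, CacheAgnostic]


def resolve_policy(name: str, cache_disabled: bool) -> Policy:
    """Parse a policy string; cache-aware choices degrade to FCFS without a cache."""
    try:
        aware = CacheAware(name)
        return CacheAgnostic.FCFS if cache_disabled else aware
    except ValueError:
        try:
            return CacheAgnostic(name)
        except ValueError:
            raise ValueError(f"Unknown schedule_policy: {name!r}")


print(resolve_policy("lpm", cache_disabled=False))    # CacheAware.LPM
print(resolve_policy("lpm", cache_disabled=True))     # CacheAgnostic.FCFS
print(resolve_policy("random", cache_disabled=True))  # CacheAgnostic.RANDOM
```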
sglang/srt/managers/scheduler.py CHANGED

```diff
@@ -76,6 +76,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -116,6 +117,14 @@ class Scheduler:
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
+        self.spec_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.decode_mem_cache_buf_multiplier = (
+            self.server_args.speculative_num_draft_tokens
+            if not self.spec_algorithm.is_none()
+            else 1
+        )
 
         # Init inter-process communication
         context = zmq.Context(2)
@@ -199,6 +208,21 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )
 
+        # Launch worker for speculative decoding if need
+        if self.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+            self.draft_worker = EAGLEWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        else:
+            self.draft_worker = None
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -855,6 +879,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
+            self.spec_algorithm,
         )
         new_batch.prepare_for_extend()
 
@@ -888,11 +913,15 @@ class Scheduler:
             return None
 
         # Check if decode out of memory
-        if not batch.check_decode_mem() or (
+        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
+            test_retract and batch.batch_size() > 10
+        ):
             old_ratio = self.new_token_ratio
 
             retracted_reqs, new_token_ratio = batch.retract_decode()
             self.new_token_ratio = new_token_ratio
+            if self.draft_worker:
+                self.draft_worker.finish_request(retracted_reqs)
 
             logger.info(
                 "Decode out of memory happened. "
@@ -926,11 +955,17 @@ class Scheduler:
         self.forward_ct += 1
 
         if self.is_generation:
-            model_worker_batch = batch.get_model_worker_batch()
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-
-                    model_worker_batch
-
+                if self.spec_algorithm.is_none():
+                    model_worker_batch = batch.get_model_worker_batch()
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    logits_output, next_token_ids, model_worker_batch, spec_info = (
+                        self.draft_worker.forward_batch_speculative_generation(batch)
+                    )
+                    batch.spec_info = spec_info
             elif batch.forward_mode.is_idle():
                 model_worker_batch = batch.get_model_worker_batch()
                 self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -974,12 +1009,10 @@ class Scheduler:
                 logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             else:
                 # Move next_token_ids and logprobs to cpu
+                next_token_ids = next_token_ids.tolist()
                 if batch.return_logprob:
                     logits_output.next_token_logprobs = (
-                        logits_output.next_token_logprobs
-                            torch.arange(len(next_token_ids), device=self.device),
-                            next_token_ids,
-                        ].tolist()
+                        logits_output.next_token_logprobs.tolist()
                     )
                     logits_output.input_token_logprobs = (
                         logits_output.input_token_logprobs.tolist()
@@ -987,7 +1020,6 @@ class Scheduler:
                     logits_output.normalized_prompt_logprobs = (
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
-                next_token_ids = next_token_ids.tolist()
 
             # Check finish conditions
             logprob_pt = 0
@@ -1064,13 +1096,9 @@ class Scheduler:
             logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
             next_token_logprobs = logits_output.next_token_logprobs
         else:
-            # Move next_token_ids and logprobs to cpu
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs[
-                    torch.arange(len(next_token_ids), device=self.device),
-                    next_token_ids,
-                ].tolist()
             next_token_ids = next_token_ids.tolist()
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs.tolist()
 
         self.token_to_kv_pool.free_group_begin()
 
@@ -1084,7 +1112,10 @@ class Scheduler:
                 self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                 continue
 
-
+            if batch.spec_algorithm.is_none():
+                # speculative worker will solve the output_ids in speculative decoding
+                req.output_ids.append(next_token_id)
+
             req.check_finished()
 
             if req.finished():
@@ -1095,10 +1126,10 @@ class Scheduler:
                 req.output_token_logprobs_idx.append(next_token_id)
                 if req.top_logprobs_num > 0:
                     req.output_top_logprobs_val.append(
-                        logits_output.
+                        logits_output.next_token_top_logprobs_val[i]
                     )
                     req.output_top_logprobs_idx.append(
-                        logits_output.
+                        logits_output.next_token_top_logprobs_idx[i]
                     )
 
             if req.grammar is not None:
@@ -1200,8 +1231,9 @@ class Scheduler:
                 req.output_top_logprobs_idx.extend(
                     output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
                 )
-
-            req.
+
+            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
+            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
 
         return num_input_logprobs
 
@@ -1258,6 +1290,9 @@ class Scheduler:
                 # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                 or (not req.stream and len(req.output_ids) % 50 == 0)
             ):
+                if self.draft_worker and req.finished():
+                    self.draft_worker.finish_request(req)
+
                 rids.append(req.rid)
                 finished_reasons.append(
                     req.finished_reason.to_json() if req.finished_reason else None
@@ -1329,11 +1364,11 @@ class Scheduler:
             embeddings = []
             prompt_tokens = []
             for req in reqs:
-
-
-
-
-
+                if req.finished():
+                    rids.append(req.rid)
+                    finished_reasons.append(req.finished_reason.to_json())
+                    embeddings.append(req.embedding)
+                    prompt_tokens.append(len(req.origin_input_ids))
             self.send_to_detokenizer.send_pyobj(
                 BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
             )
@@ -1389,6 +1424,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
+            self.spec_algorithm,
         )
         idle_batch.prepare_for_idle()
         return idle_batch
```
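The scheduler.py changes above gate batch execution on `spec_algorithm`: without a speculative algorithm the target `tp_worker` runs the batch directly, otherwise an EAGLE draft worker runs draft-and-verify and hands back updated `spec_info` alongside the usual outputs. A toy sketch of that dispatch shape follows; every class below is an invented stand-in, not an sglang worker.

```python
# Toy dispatch sketch under invented types; it only demonstrates the branching shape.
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class Batch:
    input_ids: List[int]
    spec_info: Optional[Any] = None


class TargetWorker:
    def forward(self, batch: Batch) -> Tuple[str, List[int]]:
        # Stand-in for a full forward pass over the batch.
        return "logits", [tok + 1 for tok in batch.input_ids]


class DraftWorker:
    def __init__(self, target: TargetWorker):
        self.target = target

    def forward_speculative(self, batch: Batch) -> Tuple[str, List[int], Any]:
        # Draft-then-verify: propose tokens, verify them with the target worker,
        # and return per-batch speculative state for the next step.
        logits, verified = self.target.forward(batch)
        return logits, verified, {"accepted": len(verified)}


def run_step(
    batch: Batch, target: TargetWorker, draft: Optional[DraftWorker]
) -> List[int]:
    if draft is None:
        _, next_ids = target.forward(batch)
    else:
        _, next_ids, spec_info = draft.forward_speculative(batch)
        batch.spec_info = spec_info  # carried into the next scheduling step
    return next_ids


target = TargetWorker()
print(run_step(Batch([1, 2, 3]), target, None))                 # plain decode path
print(run_step(Batch([1, 2, 3]), target, DraftWorker(target)))  # speculative path
```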