sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl
This diff reflects the changes between two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -0
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +22 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
- sglang/srt/layers/attention/flashinfer_backend.py +215 -83
- sglang/srt/layers/attention/torch_native_backend.py +1 -38
- sglang/srt/layers/attention/triton_backend.py +20 -11
- sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
- sglang/srt/layers/linear.py +159 -55
- sglang/srt/layers/logits_processor.py +170 -215
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
- sglang/srt/layers/parameter.py +431 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/fp8.py +3 -3
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/sampler.py +57 -21
- sglang/srt/layers/torchao_utils.py +17 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -1
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +1 -2
- sglang/srt/managers/schedule_batch.py +33 -3
- sglang/srt/managers/schedule_policy.py +159 -90
- sglang/srt/managers/scheduler.py +68 -28
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +27 -21
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/memory_pool.py +206 -1
- sglang/srt/metrics/collector.py +22 -30
- sglang/srt/model_executor/cuda_graph_runner.py +129 -77
- sglang/srt/model_executor/forward_batch_info.py +51 -21
- sglang/srt/model_executor/model_runner.py +72 -64
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek_v2.py +34 -7
- sglang/srt/models/grok.py +109 -29
- sglang/srt/models/llama.py +9 -2
- sglang/srt/openai_api/adapter.py +0 -17
- sglang/srt/openai_api/protocol.py +3 -3
- sglang/srt/sampling/sampling_batch_info.py +22 -0
- sglang/srt/sampling/sampling_params.py +9 -1
- sglang/srt/server.py +20 -13
- sglang/srt/server_args.py +120 -58
- sglang/srt/speculative/build_eagle_tree.py +347 -0
- sglang/srt/speculative/eagle_utils.py +626 -0
- sglang/srt/speculative/eagle_worker.py +184 -0
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/utils.py +47 -7
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
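The upgrade itself is the one-line bump in sglang/version.py listed above. A quick way to confirm which wheel is active after installing the newer version (reading `sglang.__version__` is the usual packaging convention and an assumption here, not something shown in this diff):

    # Assumes the package was upgraded, e.g. with: pip install "sglang==0.4.1.post5"
    import sglang

    # Expected to print 0.4.1.post5 after the upgrade; 0.4.1.post3 before it.
    print(sglang.__version__)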
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -18,7 +18,7 @@ import random
 from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum, auto
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Set, Union

 import torch

@@ -50,13 +50,26 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
 )


+class CacheAwarePolicy(Enum):
+    """Scheduling policies that are aware of the tree cache."""
+
+    LPM = "lpm"  # longest prefix match
+    DFS_WEIGHT = "dfs-weight"  # depth-first search weighting
+
+
+class CacheAgnosticPolicy(Enum):
+    """Scheduling policies that are not aware of the tree cache."""
+
+    FCFS = "fcfs"  # first come first serve
+    LOF = "lof"  # longest output first
+    RANDOM = "random"
+
+
 class SchedulePolicy:
-
-        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
-            # LPM and DFS-weight is meaningless when the tree cache is disabled.
-            policy = "fcfs"
+    Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]

-
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+        self.policy = self._validate_and_adjust_policy(policy, tree_cache)
         self.tree_cache = tree_cache

         # It is used to find the matching prefix for in-batch prefix caching.
@@ -64,110 +77,166 @@ class SchedulePolicy:
             req_to_token_pool=None, token_to_kv_pool=None, disable=False
         )

-    def calc_priority(self, waiting_queue: List[Req]):
-
-            # Turn off the expensive prefix matching and sorting when the #queue is large.
-            policy = "fcfs"
-        else:
-            policy = self.policy
+    def calc_priority(self, waiting_queue: List[Req]) -> bool:
+        policy = self._determine_active_policy(waiting_queue)

-        # Compute matched prefix length
         prefix_computed = False
-        if policy
-
-            temporary_deprioritized =
-
-
-
-
-
-                # NOTE: the prefix_indices must always be aligned with last_node
-                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
-                    rid=r.rid, key=prefix_ids
+        if isinstance(policy, CacheAwarePolicy):
+            prefix_computed = True
+            temporary_deprioritized = self._compute_prefix_matches(
+                waiting_queue, policy
+            )
+            if policy == CacheAwarePolicy.LPM:
+                SchedulePolicy._sort_by_longest_prefix(
+                    waiting_queue, temporary_deprioritized
                 )
+            elif policy == CacheAwarePolicy.DFS_WEIGHT:
+                SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache)
+            else:
+                raise ValueError(f"Unknown CacheAware Policy: {policy=}")
+        else:
+            if policy == CacheAgnosticPolicy.FCFS:
+                pass
+            elif policy == CacheAgnosticPolicy.LOF:
+                SchedulePolicy._sort_by_longest_output(waiting_queue)
+            elif policy == CacheAgnosticPolicy.RANDOM:
+                SchedulePolicy._sort_randomly(waiting_queue)
+            else:
+                raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}")

-
-                # If there are more than 1 request that have small matching prefix from
-                # existing cache, but all those requests share the same prefix, we prefer
-                # to schedule only one of them so that we can increase the cache hit rate.
-                # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
-                # threshold means we cannot use in-batch prefix caching for short prefixes.
-                # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
-                if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
-                    in_batch_matching_prefixes, _ = (
-                        self.waiting_queue_radix_tree.match_prefix(
-                            rid=r.rid, key=prefix_ids
-                        )
-                    )
-                    if (
-                        len(in_batch_matching_prefixes)
-                        >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
-                    ):
-                        temporary_deprioritized.add(r.rid)
-                    else:
-                        # Insert with a dummy key
-                        self.waiting_queue_radix_tree.insert(
-                            prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
-                        )
+        return prefix_computed

-
+    def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
+        if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
+            # Turn off the expensive prefix matching and sorting when the #queue is large.
+            return CacheAgnosticPolicy.FCFS
+        return self.policy
+
+    def _validate_and_adjust_policy(
+        self, policy: str, tree_cache: BasePrefixCache
+    ) -> Policy:
+        """
+        Validates the policy and adjusts it if necessary based on tree cache settings.
+        """
+        try:
+            policy_enum = CacheAwarePolicy(policy)
+            if tree_cache.disable:
+                # If tree_cache is disabled, using CacheAgnosticPolicy policy
+                return CacheAgnosticPolicy.FCFS
+            return policy_enum
+        except ValueError:
+            try:
+                return CacheAgnosticPolicy(policy)
+            except ValueError:
+                raise ValueError(f"Unknown schedule_policy: {policy=}")
+
+    def _compute_prefix_matches(
+        self, waiting_queue: List[Req], policy: CacheAwarePolicy
+    ) -> Set[int]:
+        """
+        Computes and caches the matching prefixes for requests in the waiting queue,
+        and handles in-batch prefix caching logic.
+        """
+        temporary_deprioritized: Set[int] = set()
+        self.waiting_queue_radix_tree.reset()
+
+        for r in waiting_queue:
+            prefix_ids = r.adjust_max_prefix_ids()
+
+            # NOTE: the prefix_indices must always be aligned with last_node
+            r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                rid=r.rid, key=prefix_ids
+            )

-
-                #
-
-
-
-
-
+            # NOTE(sang): This logic is for in-batch prefix caching;
+            # If there are more than 1 request that have small matching prefix from
+            # existing cache, but all those requests share the same prefix, we prefer
+            # to schedule only one of them so that we can increase the cache hit rate.
+            # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small
+            # threshold means we cannot use in-batch prefix caching for short prefixes.
+            # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
+            if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
+                in_batch_matching_prefixes, _ = (
+                    self.waiting_queue_radix_tree.match_prefix(
+                        rid=r.rid, key=prefix_ids
+                    )
                 )
+                if (
+                    len(in_batch_matching_prefixes)
+                    >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD
+                ):
+                    temporary_deprioritized.add(r.rid)
+                else:
+                    # Insert with a dummy key
+                    self.waiting_queue_radix_tree.insert(
+                        prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool)
+                    )
+        return temporary_deprioritized
+
+    @staticmethod
+    def _sort_by_longest_prefix(
+        waiting_queue: List[Req], temporary_deprioritized: Set[int]
+    ) -> None:
+        """Sorts the waiting queue based on the longest prefix match."""
+        waiting_queue.sort(
+            key=lambda r: (
+                -len(r.prefix_indices)
+                if r.rid not in temporary_deprioritized
+                else float("inf")
            )
-
-            # first come first serve
-            pass
-        elif policy == "lof":
-            # longest output first
-            waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-        elif policy == "random":
-            random.shuffle(waiting_queue)
-        elif policy == "dfs-weight":
-            # Experimental policy based on custom weights
-            last_node_to_reqs = defaultdict(list)
-            for req in waiting_queue:
-                last_node_to_reqs[req.last_node].append(req)
-
-            node_to_weight = defaultdict(int)
-            for node in last_node_to_reqs:
-                node_to_weight[node] = len(last_node_to_reqs[node])
-            self.calc_weight(self.tree_cache.root_node, node_to_weight)
-
-            waiting_queue.clear()
-            self.get_dfs_priority(
-                self.tree_cache.root_node,
-                node_to_weight,
-                last_node_to_reqs,
-                waiting_queue,
-            )
-        else:
-            raise ValueError(f"Unknown schedule_policy: {policy=}")
+        )

-
+    @staticmethod
+    def _sort_by_dfs_weight(
+        waiting_queue: List[Req], tree_cache: BasePrefixCache
+    ) -> None:
+        """Sorts the waiting queue based on a depth-first search weighting."""
+        last_node_to_reqs = defaultdict(list)
+        for req in waiting_queue:
+            last_node_to_reqs[req.last_node].append(req)
+
+        node_to_weight = defaultdict(int)
+        for node in last_node_to_reqs:
+            node_to_weight[node] = len(last_node_to_reqs[node])
+        SchedulePolicy._calc_weight(tree_cache.root_node, node_to_weight)
+
+        waiting_queue.clear()
+        SchedulePolicy._get_dfs_priority(
+            tree_cache.root_node,
+            node_to_weight,
+            last_node_to_reqs,
+            waiting_queue,
+        )
+
+    @staticmethod
+    def _sort_by_longest_output(waiting_queue: List[Req]) -> None:
+        """Sorts the waiting queue based on the longest output (max_new_tokens)."""
+        waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)

-
+    @staticmethod
+    def _sort_randomly(waiting_queue: List[Req]) -> None:
+        """Shuffles the waiting queue randomly."""
+        random.shuffle(waiting_queue)
+
+    @staticmethod
+    def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None:
         for child in cur_node.children.values():
-
+            SchedulePolicy._calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]

-
-
+    @staticmethod
+    def _get_dfs_priority(
         cur_node: TreeNode,
         node_to_priority: Dict[TreeNode, int],
         last_node_to_reqs: Dict[TreeNode, List[Req]],
         q: List,
-    ):
+    ) -> None:
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
-
+            SchedulePolicy._get_dfs_priority(
+                child, node_to_priority, last_node_to_reqs, q
+            )
         q.extend(last_node_to_reqs[cur_node])


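Taken together, the schedule_policy.py changes above replace free-form policy strings with two enums and centralize the fallback rules: a cache-aware policy degrades to FCFS when the tree cache is disabled, and LPM is also skipped once the waiting queue exceeds 128 requests. A minimal standalone sketch of that resolution logic, restated outside the SchedulePolicy class for illustration (the enum values are copied from the diff; the helper name `resolve_policy` is invented here):

    from enum import Enum


    class CacheAwarePolicy(Enum):
        LPM = "lpm"
        DFS_WEIGHT = "dfs-weight"


    class CacheAgnosticPolicy(Enum):
        FCFS = "fcfs"
        LOF = "lof"
        RANDOM = "random"


    def resolve_policy(policy: str, cache_disabled: bool, queue_len: int):
        """Illustrative restatement of _validate_and_adjust_policy + _determine_active_policy."""
        try:
            parsed = CacheAwarePolicy(policy)
            # Cache-aware policies are meaningless without a tree cache.
            if cache_disabled:
                return CacheAgnosticPolicy.FCFS
        except ValueError:
            parsed = CacheAgnosticPolicy(policy)  # raises ValueError for unknown strings
        # Prefix matching is too expensive for large queues.
        if parsed == CacheAwarePolicy.LPM and queue_len > 128:
            return CacheAgnosticPolicy.FCFS
        return parsed


    print(resolve_policy("lpm", cache_disabled=False, queue_len=4))    # CacheAwarePolicy.LPM
    print(resolve_policy("lpm", cache_disabled=True, queue_len=4))     # CacheAgnosticPolicy.FCFS
    print(resolve_policy("lpm", cache_disabled=False, queue_len=500))  # CacheAgnosticPolicy.FCFS

Keeping the two families in separate enums lets the scheduler branch on `isinstance(policy, CacheAwarePolicy)` instead of comparing strings, which is exactly what the new `calc_priority` does.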
sglang/srt/managers/scheduler.py
CHANGED
@@ -76,6 +76,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
@@ -116,6 +117,14 @@ class Scheduler:
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
         self.enable_metrics = server_args.enable_metrics
+        self.spec_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
+        self.decode_mem_cache_buf_multiplier = (
+            self.server_args.speculative_num_draft_tokens
+            if not self.spec_algorithm.is_none()
+            else 1
+        )

         # Init inter-process communication
         context = zmq.Context(2)
@@ -199,6 +208,21 @@ class Scheduler:
             nccl_port=port_args.nccl_port,
         )

+        # Launch worker for speculative decoding if need
+        if self.spec_algorithm.is_eagle():
+            from sglang.srt.speculative.eagle_worker import EAGLEWorker
+
+            self.draft_worker = EAGLEWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
+        else:
+            self.draft_worker = None
+
         # Get token and memory info from the model worker
         (
             self.max_total_num_tokens,
@@ -855,6 +879,7 @@ class Scheduler:
             self.tree_cache,
             self.model_config,
             self.enable_overlap,
+            self.spec_algorithm,
         )
         new_batch.prepare_for_extend()

@@ -888,11 +913,15 @@ class Scheduler:
            return None

        # Check if decode out of memory
-        if not batch.check_decode_mem() or (
+        if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
+            test_retract and batch.batch_size() > 10
+        ):
            old_ratio = self.new_token_ratio

            retracted_reqs, new_token_ratio = batch.retract_decode()
            self.new_token_ratio = new_token_ratio
+            if self.draft_worker:
+                self.draft_worker.finish_request(retracted_reqs)

            logger.info(
                "Decode out of memory happened. "
@@ -926,11 +955,20 @@ class Scheduler:
         self.forward_ct += 1

         if self.is_generation:
-            model_worker_batch = batch.get_model_worker_batch()
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
-
-                    model_worker_batch
-
+                if self.spec_algorithm.is_none():
+                    model_worker_batch = batch.get_model_worker_batch()
+                    logits_output, next_token_ids = (
+                        self.tp_worker.forward_batch_generation(model_worker_batch)
+                    )
+                else:
+                    (
+                        logits_output,
+                        next_token_ids,
+                        model_worker_batch,
+                        num_accepted_tokens,
+                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
+                    self.num_generated_tokens += num_accepted_tokens
             elif batch.forward_mode.is_idle():
                 model_worker_batch = batch.get_model_worker_batch()
                 self.tp_worker.forward_batch_idle(model_worker_batch)
@@ -974,12 +1012,10 @@ class Scheduler:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
        else:
            # Move next_token_ids and logprobs to cpu
+            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                logits_output.next_token_logprobs = (
-                    logits_output.next_token_logprobs
-                        torch.arange(len(next_token_ids), device=self.device),
-                        next_token_ids,
-                    ].tolist()
+                    logits_output.next_token_logprobs.tolist()
                )
                logits_output.input_token_logprobs = (
                    logits_output.input_token_logprobs.tolist()
@@ -987,7 +1023,6 @@ class Scheduler:
                logits_output.normalized_prompt_logprobs = (
                    logits_output.normalized_prompt_logprobs.tolist()
                )
-            next_token_ids = next_token_ids.tolist()

        # Check finish conditions
        logprob_pt = 0
@@ -1064,13 +1099,9 @@ class Scheduler:
            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
            next_token_logprobs = logits_output.next_token_logprobs
        else:
-            # Move next_token_ids and logprobs to cpu
-            if batch.return_logprob:
-                next_token_logprobs = logits_output.next_token_logprobs[
-                    torch.arange(len(next_token_ids), device=self.device),
-                    next_token_ids,
-                ].tolist()
            next_token_ids = next_token_ids.tolist()
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs.tolist()

        self.token_to_kv_pool.free_group_begin()

@@ -1084,7 +1115,10 @@ class Scheduler:
                self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
                continue

-
+            if batch.spec_algorithm.is_none():
+                # speculative worker will solve the output_ids in speculative decoding
+                req.output_ids.append(next_token_id)
+
            req.check_finished()

            if req.finished():
@@ -1095,10 +1129,10 @@ class Scheduler:
                req.output_token_logprobs_idx.append(next_token_id)
                if req.top_logprobs_num > 0:
                    req.output_top_logprobs_val.append(
-                        logits_output.
+                        logits_output.next_token_top_logprobs_val[i]
                    )
                    req.output_top_logprobs_idx.append(
-                        logits_output.
+                        logits_output.next_token_top_logprobs_idx[i]
                    )

                if req.grammar is not None:
@@ -1200,8 +1234,9 @@ class Scheduler:
            req.output_top_logprobs_idx.extend(
                output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
            )
-
-            req.
+
+            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
+            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])

        return num_input_logprobs

@@ -1258,6 +1293,9 @@ class Scheduler:
                # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
                or (not req.stream and len(req.output_ids) % 50 == 0)
            ):
+                if self.draft_worker and req.finished():
+                    self.draft_worker.finish_request(req)
+
                rids.append(req.rid)
                finished_reasons.append(
                    req.finished_reason.to_json() if req.finished_reason else None
@@ -1329,11 +1367,11 @@ class Scheduler:
        embeddings = []
        prompt_tokens = []
        for req in reqs:
-
-
-
-
-
+            if req.finished():
+                rids.append(req.rid)
+                finished_reasons.append(req.finished_reason.to_json())
+                embeddings.append(req.embedding)
+                prompt_tokens.append(len(req.origin_input_ids))
        self.send_to_detokenizer.send_pyobj(
            BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
        )
@@ -1389,6 +1427,7 @@ class Scheduler:
            self.tree_cache,
            self.model_config,
            self.enable_overlap,
+            self.spec_algorithm,
        )
        idle_batch.prepare_for_idle()
        return idle_batch
@@ -1477,8 +1516,9 @@ class Scheduler:
        return success, message

    def update_weights_from_distributed(
-        self,
-
+        self,
+        recv_req: UpdateWeightsFromDistributedReqInput,
+    ) -> Tuple[bool, str]:
        """Update the online model parameter."""
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
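A detail worth calling out in the scheduler.py hunks above: the scheduler used to gather each request's sampled-token log-probability itself via `logits_output.next_token_logprobs[torch.arange(...), next_token_ids]`, and now simply calls `.tolist()` on a tensor that already holds one value per request, which suggests the gather moved upstream (the sampler and logits processor are also modified in this release). A small PyTorch sketch of what that removed indexing computes, assuming `next_token_logprobs` was a [batch_size, vocab_size] log-probability matrix:

    import torch

    # Hypothetical shapes for illustration: 3 requests, vocabulary of 5 tokens.
    logprobs = torch.log_softmax(torch.randn(3, 5), dim=-1)   # [batch, vocab]
    next_token_ids = torch.tensor([4, 0, 2])                  # sampled token per request

    # The indexing removed from scheduler.py: pick each request's own sampled-token logprob.
    gathered = logprobs[torch.arange(len(next_token_ids)), next_token_ids]

    # Equivalent formulation with torch.gather.
    gathered_2 = logprobs.gather(1, next_token_ids.unsqueeze(1)).squeeze(1)

    assert torch.allclose(gathered, gathered_2)
    print(gathered.tolist())  # a per-request list, which is all the new scheduler code converts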
sglang/srt/managers/session_controller.py
CHANGED
@@ -99,7 +99,7 @@ class Session:

         if last_req is not None:
             # trim bos token if it is an append
-            if req.input_ids[0] == tokenizer.bos_token_id:
+            if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id:
                 req.input_ids = req.input_ids[1:]

         input_ids = (
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -222,10 +222,8 @@ class TokenizerManager:
         is_single = obj.is_single
         if is_single:
             tokenized_obj = await self._tokenize_one_request(obj)
-            self.
-            async for response in self._wait_one_response(
-                obj, request, created_time
-            ):
+            self._send_one_request(obj, tokenized_obj, created_time)
+            async for response in self._wait_one_response(obj, request):
                 yield response
         else:
             async for response in self._handle_batch_request(
@@ -306,16 +304,24 @@ class TokenizerManager:

         return tokenized_obj

-
+    def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
-
+        tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-        """Wait for the response of one request."""
         event = asyncio.Event()
         state = ReqState([], False, event, obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        self.send_to_scheduler.send_pyobj(tokenized_obj)
+
+    async def _wait_one_response(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
+    ):
+        """Wait for the response of one request."""
+        state = self.rid_to_state[obj.rid]

         while True:
             try:
@@ -361,10 +367,8 @@ class TokenizerManager:
             for i in range(batch_size):
                 tmp_obj = obj[i]
                 tokenized_obj = await self._tokenize_one_request(tmp_obj)
-                self.
-                generators.append(
-                    self._wait_one_response(tmp_obj, request, created_time)
-                )
+                self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                generators.append(self._wait_one_response(tmp_obj, request))
                 rids.append(tmp_obj.rid)
         else:
             # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
@@ -389,10 +393,8 @@ class TokenizerManager:
             tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params)
             tokenized_obj.sampling_params.max_new_tokens = 0
             tokenized_obj.stream = False
-            self.
-            await self._wait_one_response(
-                tmp_obj, request, created_time
-            ).__anext__()
+            self._send_one_request(tmp_obj, tokenized_obj, created_time)
+            await self._wait_one_response(tmp_obj, request).__anext__()

             # Expand requests, assign new rids for them, and send them
             for i in range(batch_size):
@@ -400,10 +402,8 @@ class TokenizerManager:
                 tmp_obj = copy.copy(objs[i])
                 tokenized_obj = copy.copy(tokenized_objs[i])
                 tokenized_obj.rid = tmp_obj.regenerate_rid()
-                self.
-                generators.append(
-                    self._wait_one_response(tmp_obj, request, created_time)
-                )
+                self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                generators.append(self._wait_one_response(tmp_obj, request))
                 rids.append(tmp_obj.rid)

             # Wait for all requests
@@ -688,7 +688,7 @@ class TokenizerManager:
            if self.enable_metrics:
                completion_tokens = (
                    recv_obj.completion_tokens[i]
-                    if recv_obj
+                    if getattr(recv_obj, "completion_tokens", None)
                    else 0
                )

@@ -699,6 +699,7 @@ class TokenizerManager:
                    )
                else:
                    if completion_tokens >= 2:
+                        # Compute time_per_output_token for the streaming case
                        self.metrics_collector.observe_time_per_output_token(
                            (time.time() - state.first_token_time)
                            / (completion_tokens - 1)
@@ -714,7 +715,12 @@ class TokenizerManager:
                self.metrics_collector.observe_e2e_request_latency(
                    time.time() - state.created_time
                )
-
+                # Compute time_per_output_token for the non-streaming case
+                if (
+                    hasattr(state.obj, "stream")
+                    and not state.obj.stream
+                    and completion_tokens >= 1
+                ):
                    self.metrics_collector.observe_time_per_output_token(
                        (time.time() - state.created_time)
                        / completion_tokens