sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +7 -1
- sglang/bench_latency.py +9 -6
- sglang/bench_serving.py +46 -22
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +4 -2
- sglang/lang/ir.py +16 -7
- sglang/srt/constrained/base_tool_cache.py +1 -1
- sglang/srt/constrained/fsm_cache.py +12 -2
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/layers/activation.py +32 -0
- sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
- sglang/srt/layers/extend_attention.py +9 -2
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +65 -0
- sglang/srt/layers/logits_processor.py +7 -2
- sglang/srt/layers/pooler.py +50 -0
- sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
- sglang/srt/layers/radix_attention.py +40 -16
- sglang/srt/managers/detokenizer_manager.py +31 -9
- sglang/srt/managers/io_struct.py +63 -0
- sglang/srt/managers/policy_scheduler.py +173 -25
- sglang/srt/managers/schedule_batch.py +115 -97
- sglang/srt/managers/tokenizer_manager.py +194 -112
- sglang/srt/managers/tp_worker.py +290 -359
- sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
- sglang/srt/mem_cache/chunk_cache.py +43 -20
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +74 -40
- sglang/srt/model_executor/cuda_graph_runner.py +71 -25
- sglang/srt/model_executor/forward_batch_info.py +293 -156
- sglang/srt/model_executor/model_runner.py +77 -57
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/deepseek.py +2 -2
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +11 -6
- sglang/srt/models/grok.py +50 -396
- sglang/srt/models/internlm2.py +2 -7
- sglang/srt/models/llama2.py +4 -4
- sglang/srt/models/llama_embedding.py +88 -0
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/mixtral.py +56 -254
- sglang/srt/models/mixtral_quant.py +1 -4
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_moe.py +2 -13
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +187 -48
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/penaltylib/__init__.py +13 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
- sglang/srt/sampling_params.py +31 -8
- sglang/srt/server.py +91 -29
- sglang/srt/server_args.py +32 -19
- sglang/srt/utils.py +32 -15
- sglang/test/run_eval.py +10 -1
- sglang/test/runners.py +81 -73
- sglang/test/simple_eval_humaneval.py +2 -8
- sglang/test/simple_eval_mgsm.py +203 -0
- sglang/test/srt/sampling/penaltylib/utils.py +337 -0
- sglang/test/test_layernorm.py +60 -0
- sglang/test/test_programs.py +36 -7
- sglang/test/test_utils.py +24 -2
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
- sglang-0.2.13.dist-info/RECORD +112 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
- sglang/srt/layers/linear.py +0 -884
- sglang/srt/layers/quantization/__init__.py +0 -64
- sglang/srt/layers/quantization/fp8.py +0 -677
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.11.dist-info/RECORD +0 -102
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
- {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
+
from typing import Callable
|
2
3
|
|
3
4
|
|
4
5
|
class BasePrefixCache(ABC):
|
@@ -17,11 +18,15 @@ class BasePrefixCache(ABC):
|
|
17
18
|
pass
|
18
19
|
|
19
20
|
@abstractmethod
|
20
|
-
def
|
21
|
+
def cache_finished_req(self, **kwargs):
|
21
22
|
pass
|
22
23
|
|
23
24
|
@abstractmethod
|
24
|
-
def
|
25
|
+
def cache_unfinished_req(self, **kwargs):
|
26
|
+
pass
|
27
|
+
|
28
|
+
@abstractmethod
|
29
|
+
def evict(self, num_tokens: int, evict_callback: Callable):
|
25
30
|
pass
|
26
31
|
|
27
32
|
@abstractmethod
|
@@ -37,7 +42,7 @@ class BasePrefixCache(ABC):
|
|
37
42
|
pass
|
38
43
|
|
39
44
|
def total_size(self):
|
40
|
-
raise NotImplementedError
|
45
|
+
raise NotImplementedError()
|
41
46
|
|
42
47
|
def pretty_print(self):
|
43
|
-
raise NotImplementedError
|
48
|
+
raise NotImplementedError()
|
@@ -1,6 +1,14 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
2
4
|
|
3
|
-
from
|
5
|
+
from typing import TYPE_CHECKING, Callable, List, Optional
|
6
|
+
|
7
|
+
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
8
|
+
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
9
|
+
|
10
|
+
if TYPE_CHECKING:
|
11
|
+
from sglang.srt.managers.schedule_batch import Req
|
4
12
|
|
5
13
|
|
6
14
|
class ChunkCacheEntry:
|
@@ -10,7 +18,9 @@ class ChunkCacheEntry:
|
|
10
18
|
|
11
19
|
|
12
20
|
class ChunkCache(BasePrefixCache):
|
13
|
-
def __init__(
|
21
|
+
def __init__(
|
22
|
+
self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
|
23
|
+
):
|
14
24
|
self.disable = True
|
15
25
|
self.req_to_token_pool = req_to_token_pool
|
16
26
|
self.token_to_kv_pool = token_to_kv_pool
|
@@ -20,34 +30,47 @@ class ChunkCache(BasePrefixCache):
|
|
20
30
|
def reset(self):
|
21
31
|
self.entries = {}
|
22
32
|
|
23
|
-
def match_prefix(self, rid,
|
33
|
+
def match_prefix(self, rid: int, key: List[int]):
|
24
34
|
if rid not in self.entries:
|
25
35
|
return [], None
|
26
36
|
|
27
37
|
entry = self.entries[rid]
|
28
|
-
|
38
|
+
max_prefix_len = len(key)
|
39
|
+
return entry.value[:max_prefix_len], entry
|
29
40
|
|
30
|
-
def
|
31
|
-
|
32
|
-
|
33
|
-
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
34
|
-
if del_in_memory_pool:
|
35
|
-
assert rid in self.entries
|
36
|
-
self.req_to_token_pool.free(req_pool_idx)
|
37
|
-
self.token_to_kv_pool.free(indices)
|
38
|
-
return
|
41
|
+
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
42
|
+
if token_ids is None:
|
43
|
+
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
39
44
|
|
40
|
-
|
41
|
-
|
45
|
+
kv_indices = self.req_to_token_pool.req_to_token[
|
46
|
+
req.req_pool_idx, : len(token_ids)
|
47
|
+
]
|
48
|
+
self.req_to_token_pool.free(req.req_pool_idx)
|
49
|
+
self.token_to_kv_pool.free(kv_indices)
|
42
50
|
|
43
|
-
|
44
|
-
|
45
|
-
|
51
|
+
if req.rid in self.entries:
|
52
|
+
del self.entries[req.rid]
|
53
|
+
|
54
|
+
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
55
|
+
if token_ids is None:
|
56
|
+
token_ids = req.fill_ids
|
57
|
+
|
58
|
+
kv_indices = self.req_to_token_pool.req_to_token[
|
59
|
+
req.req_pool_idx, : len(token_ids)
|
60
|
+
]
|
61
|
+
|
62
|
+
if req.rid not in self.entries:
|
63
|
+
self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
|
64
|
+
|
65
|
+
entry = self.entries[req.rid]
|
66
|
+
entry.value = kv_indices
|
67
|
+
req.prefix_indices = kv_indices
|
68
|
+
req.last_node = entry
|
46
69
|
|
47
70
|
def insert(self):
|
48
|
-
raise NotImplementedError
|
71
|
+
raise NotImplementedError()
|
49
72
|
|
50
|
-
def evict(self, num_tokens, evict_callback):
|
73
|
+
def evict(self, num_tokens: int, evict_callback: Callable):
|
51
74
|
pass
|
52
75
|
|
53
76
|
def inc_lock_ref(self, node):
|
@@ -16,7 +16,7 @@ limitations under the License.
|
|
16
16
|
"""Memory pool."""
|
17
17
|
|
18
18
|
import logging
|
19
|
-
from typing import List
|
19
|
+
from typing import List, Union
|
20
20
|
|
21
21
|
import torch
|
22
22
|
|
@@ -42,7 +42,7 @@ class ReqToTokenPool:
|
|
42
42
|
|
43
43
|
return select_index
|
44
44
|
|
45
|
-
def free(self, free_index):
|
45
|
+
def free(self, free_index: Union[int, List[int]]):
|
46
46
|
if isinstance(free_index, (int,)):
|
47
47
|
self.free_slots.append(free_index)
|
48
48
|
else:
|
@@ -1,3 +1,5 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
1
3
|
"""
|
2
4
|
Copyright 2023-2024 SGLang Team
|
3
5
|
Licensed under the Apache License, Version 2.0 (the "License");
|
@@ -20,10 +22,15 @@ The radix tree data structure for managing the KV cache.
|
|
20
22
|
import heapq
|
21
23
|
import time
|
22
24
|
from collections import defaultdict
|
25
|
+
from typing import TYPE_CHECKING, Callable, List, Optional
|
23
26
|
|
24
27
|
import torch
|
25
28
|
|
26
|
-
from sglang.srt.mem_cache.
|
29
|
+
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
30
|
+
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
31
|
+
|
32
|
+
if TYPE_CHECKING:
|
33
|
+
from sglang.srt.managers.schedule_batch import Req
|
27
34
|
|
28
35
|
|
29
36
|
class TreeNode:
|
@@ -39,7 +46,7 @@ class TreeNode:
|
|
39
46
|
return self.last_access_time < other.last_access_time
|
40
47
|
|
41
48
|
|
42
|
-
def _key_match(key0, key1):
|
49
|
+
def _key_match(key0: List, key1: List):
|
43
50
|
i = 0
|
44
51
|
for k0, k1 in zip(key0, key1):
|
45
52
|
if k0 != k1:
|
@@ -49,7 +56,12 @@ def _key_match(key0, key1):
|
|
49
56
|
|
50
57
|
|
51
58
|
class RadixCache(BasePrefixCache):
|
52
|
-
def __init__(
|
59
|
+
def __init__(
|
60
|
+
self,
|
61
|
+
req_to_token_pool: ReqToTokenPool,
|
62
|
+
token_to_kv_pool: BaseTokenToKVPool,
|
63
|
+
disable: bool = False,
|
64
|
+
):
|
53
65
|
self.req_to_token_pool = req_to_token_pool
|
54
66
|
self.token_to_kv_pool = token_to_kv_pool
|
55
67
|
self.disable = disable
|
@@ -64,7 +76,7 @@ class RadixCache(BasePrefixCache):
|
|
64
76
|
self.root_node.lock_ref = 1
|
65
77
|
self.evictable_size_ = 0
|
66
78
|
|
67
|
-
def match_prefix(self, key, **kwargs):
|
79
|
+
def match_prefix(self, key: List, **kwargs):
|
68
80
|
if self.disable:
|
69
81
|
return [], self.root_node
|
70
82
|
|
@@ -74,10 +86,10 @@ class RadixCache(BasePrefixCache):
|
|
74
86
|
if value:
|
75
87
|
value = torch.concat(value)
|
76
88
|
else:
|
77
|
-
value = torch.tensor([], dtype=torch.
|
89
|
+
value = torch.tensor([], dtype=torch.int32)
|
78
90
|
return value, last_node[0]
|
79
91
|
|
80
|
-
def insert(self, key, value=None):
|
92
|
+
def insert(self, key: List, value=None):
|
81
93
|
if self.disable:
|
82
94
|
return 0
|
83
95
|
|
@@ -85,40 +97,54 @@ class RadixCache(BasePrefixCache):
|
|
85
97
|
value = [x for x in key]
|
86
98
|
return self._insert_helper(self.root_node, key, value)
|
87
99
|
|
88
|
-
def
|
89
|
-
|
90
|
-
token_ids
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
**kwargs,
|
96
|
-
):
|
97
|
-
# Insert the request into radix cache
|
98
|
-
indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
|
99
|
-
new_prefix_len = self.insert(token_ids, indices.clone())
|
100
|
+
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
101
|
+
"""Cache request when it finishes."""
|
102
|
+
if token_ids is None:
|
103
|
+
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
104
|
+
kv_indices = self.req_to_token_pool.req_to_token[
|
105
|
+
req.req_pool_idx, : len(token_ids)
|
106
|
+
]
|
100
107
|
|
101
108
|
if self.disable:
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
return torch.tensor([], dtype=torch.int64), self.root_node
|
109
|
+
self.token_to_kv_pool.free(kv_indices)
|
110
|
+
self.req_to_token_pool.free(req.req_pool_idx)
|
111
|
+
return
|
106
112
|
|
107
113
|
# Radix Cache takes one ref in memory pool
|
108
|
-
self.
|
114
|
+
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
115
|
+
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
|
109
116
|
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
117
|
+
# Remove req slot release the cache lock
|
118
|
+
self.req_to_token_pool.free(req.req_pool_idx)
|
119
|
+
self.dec_lock_ref(req.last_node)
|
120
|
+
|
121
|
+
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
|
122
|
+
"""Cache request when it is unfinished."""
|
123
|
+
if self.disable:
|
124
|
+
return
|
115
125
|
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
126
|
+
if token_ids is None:
|
127
|
+
token_ids = req.fill_ids
|
128
|
+
|
129
|
+
kv_indices = self.req_to_token_pool.req_to_token[
|
130
|
+
req.req_pool_idx, : len(token_ids)
|
131
|
+
]
|
132
|
+
|
133
|
+
# Radix Cache takes one ref in memory pool
|
134
|
+
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
135
|
+
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
|
136
|
+
|
137
|
+
# The prefix indices could be updated, reuse it
|
138
|
+
new_indices, new_last_node = self.match_prefix(token_ids)
|
139
|
+
assert len(new_indices) == len(token_ids)
|
140
|
+
self.req_to_token_pool.req_to_token[
|
141
|
+
req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
|
142
|
+
] = new_indices[len(req.prefix_indices) :]
|
143
|
+
|
144
|
+
self.dec_lock_ref(req.last_node)
|
145
|
+
self.inc_lock_ref(new_last_node)
|
146
|
+
req.prefix_indices = new_indices
|
147
|
+
req.last_node = new_last_node
|
122
148
|
|
123
149
|
def pretty_print(self):
|
124
150
|
self._print_helper(self.root_node, 0)
|
@@ -127,7 +153,7 @@ class RadixCache(BasePrefixCache):
|
|
127
153
|
def total_size(self):
|
128
154
|
return self._total_size_helper(self.root_node)
|
129
155
|
|
130
|
-
def evict(self, num_tokens, evict_callback):
|
156
|
+
def evict(self, num_tokens: int, evict_callback: Callable):
|
131
157
|
if self.disable:
|
132
158
|
return
|
133
159
|
|
@@ -151,6 +177,9 @@ class RadixCache(BasePrefixCache):
|
|
151
177
|
heapq.heappush(leaves, x.parent)
|
152
178
|
|
153
179
|
def inc_lock_ref(self, node: TreeNode):
|
180
|
+
if self.disable:
|
181
|
+
return 0
|
182
|
+
|
154
183
|
delta = 0
|
155
184
|
while node != self.root_node:
|
156
185
|
if node.lock_ref == 0:
|
@@ -161,6 +190,9 @@ class RadixCache(BasePrefixCache):
|
|
161
190
|
return delta
|
162
191
|
|
163
192
|
def dec_lock_ref(self, node: TreeNode):
|
193
|
+
if self.disable:
|
194
|
+
return 0
|
195
|
+
|
164
196
|
delta = 0
|
165
197
|
while node != self.root_node:
|
166
198
|
if node.lock_ref == 1:
|
@@ -175,7 +207,9 @@ class RadixCache(BasePrefixCache):
|
|
175
207
|
|
176
208
|
##### Internal Helper Functions #####
|
177
209
|
|
178
|
-
def _match_prefix_helper(
|
210
|
+
def _match_prefix_helper(
|
211
|
+
self, node: TreeNode, key: List, value, last_node: TreeNode
|
212
|
+
):
|
179
213
|
node.last_access_time = time.time()
|
180
214
|
if len(key) == 0:
|
181
215
|
return
|
@@ -192,7 +226,7 @@ class RadixCache(BasePrefixCache):
|
|
192
226
|
last_node[0] = child
|
193
227
|
self._match_prefix_helper(child, key[prefix_len:], value, last_node)
|
194
228
|
|
195
|
-
def _split_node(self, key, child: TreeNode, split_len):
|
229
|
+
def _split_node(self, key, child: TreeNode, split_len: int):
|
196
230
|
# new_node -> child
|
197
231
|
new_node = TreeNode()
|
198
232
|
new_node.children = {key[split_len:][0]: child}
|
@@ -206,7 +240,7 @@ class RadixCache(BasePrefixCache):
|
|
206
240
|
new_node.parent.children[key[:split_len][0]] = new_node
|
207
241
|
return new_node
|
208
242
|
|
209
|
-
def _insert_helper(self, node, key, value):
|
243
|
+
def _insert_helper(self, node: TreeNode, key: List, value):
|
210
244
|
node.last_access_time = time.time()
|
211
245
|
if len(key) == 0:
|
212
246
|
return 0
|
@@ -237,7 +271,7 @@ class RadixCache(BasePrefixCache):
|
|
237
271
|
self.evictable_size_ += len(value)
|
238
272
|
return 0
|
239
273
|
|
240
|
-
def _print_helper(self, node: TreeNode, indent):
|
274
|
+
def _print_helper(self, node: TreeNode, indent: int):
|
241
275
|
for _, child in node.children.items():
|
242
276
|
print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
|
243
277
|
self._print_helper(child, indent=indent + 2)
|
@@ -249,7 +283,7 @@ class RadixCache(BasePrefixCache):
|
|
249
283
|
del node.parent.children[k]
|
250
284
|
self.evictable_size_ -= len(node.key)
|
251
285
|
|
252
|
-
def _total_size_helper(self, node):
|
286
|
+
def _total_size_helper(self, node: TreeNode):
|
253
287
|
x = len(node.value)
|
254
288
|
for child in node.children.values():
|
255
289
|
x += self._total_size_helper(child)
|
@@ -33,7 +33,7 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
|
|
33
33
|
from sglang.srt.model_executor.forward_batch_info import (
|
34
34
|
ForwardMode,
|
35
35
|
InputMetadata,
|
36
|
-
|
36
|
+
update_flashinfer_indices,
|
37
37
|
)
|
38
38
|
from sglang.srt.utils import monkey_patch_vllm_all_gather
|
39
39
|
|
@@ -71,6 +71,18 @@ def patch_model(
|
|
71
71
|
tp_group.ca_comm = backup_ca_comm
|
72
72
|
|
73
73
|
|
74
|
+
def set_torch_compile_config():
|
75
|
+
import torch._dynamo.config
|
76
|
+
import torch._inductor.config
|
77
|
+
|
78
|
+
torch._inductor.config.coordinate_descent_tuning = True
|
79
|
+
torch._inductor.config.triton.unique_kernel_names = True
|
80
|
+
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
|
81
|
+
|
82
|
+
# FIXME: tmp workaround
|
83
|
+
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
84
|
+
|
85
|
+
|
74
86
|
class CudaGraphRunner:
|
75
87
|
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
|
76
88
|
self.model_runner = model_runner
|
@@ -86,8 +98,8 @@ class CudaGraphRunner:
|
|
86
98
|
self.req_pool_indices = torch.zeros(
|
87
99
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
88
100
|
)
|
89
|
-
self.seq_lens = torch.
|
90
|
-
self.position_ids_offsets = torch.
|
101
|
+
self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
|
102
|
+
self.position_ids_offsets = torch.ones(
|
91
103
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
92
104
|
)
|
93
105
|
self.out_cache_loc = torch.zeros(
|
@@ -95,9 +107,6 @@ class CudaGraphRunner:
|
|
95
107
|
)
|
96
108
|
|
97
109
|
# FlashInfer inputs
|
98
|
-
self.flashinfer_workspace_buffer = (
|
99
|
-
self.model_runner.flashinfer_workspace_buffers[0]
|
100
|
-
)
|
101
110
|
self.flashinfer_kv_indptr = torch.zeros(
|
102
111
|
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
|
103
112
|
)
|
@@ -109,9 +118,29 @@ class CudaGraphRunner:
|
|
109
118
|
self.flashinfer_kv_last_page_len = torch.ones(
|
110
119
|
(self.max_bs,), dtype=torch.int32, device="cuda"
|
111
120
|
)
|
121
|
+
if model_runner.sliding_window_size is None:
|
122
|
+
self.flashinfer_workspace_buffer = (
|
123
|
+
self.model_runner.flashinfer_workspace_buffer
|
124
|
+
)
|
125
|
+
else:
|
126
|
+
self.flashinfer_workspace_buffer = (
|
127
|
+
self.model_runner.flashinfer_workspace_buffer
|
128
|
+
)
|
129
|
+
|
130
|
+
self.flashinfer_kv_indptr = [
|
131
|
+
self.flashinfer_kv_indptr,
|
132
|
+
self.flashinfer_kv_indptr.clone(),
|
133
|
+
]
|
134
|
+
self.flashinfer_kv_indices = [
|
135
|
+
self.flashinfer_kv_indices,
|
136
|
+
self.flashinfer_kv_indices.clone(),
|
137
|
+
]
|
112
138
|
|
113
139
|
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
|
114
140
|
|
141
|
+
if use_torch_compile:
|
142
|
+
set_torch_compile_config()
|
143
|
+
|
115
144
|
def can_run(self, batch_size):
|
116
145
|
return batch_size < self.max_bs
|
117
146
|
|
@@ -156,16 +185,33 @@ class CudaGraphRunner:
|
|
156
185
|
use_tensor_cores = True
|
157
186
|
else:
|
158
187
|
use_tensor_cores = False
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
188
|
+
if self.model_runner.sliding_window_size is None:
|
189
|
+
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
190
|
+
self.flashinfer_workspace_buffer,
|
191
|
+
"NHD",
|
192
|
+
use_cuda_graph=True,
|
193
|
+
use_tensor_cores=use_tensor_cores,
|
194
|
+
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
|
195
|
+
paged_kv_indices_buffer=self.flashinfer_kv_indices,
|
196
|
+
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
|
197
|
+
)
|
198
|
+
else:
|
199
|
+
flashinfer_decode_wrapper = []
|
200
|
+
for i in range(2):
|
201
|
+
flashinfer_decode_wrapper.append(
|
202
|
+
BatchDecodeWithPagedKVCacheWrapper(
|
203
|
+
self.flashinfer_workspace_buffer,
|
204
|
+
"NHD",
|
205
|
+
use_cuda_graph=True,
|
206
|
+
use_tensor_cores=use_tensor_cores,
|
207
|
+
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1],
|
208
|
+
paged_kv_indices_buffer=self.flashinfer_kv_indices[i],
|
209
|
+
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[
|
210
|
+
:bs
|
211
|
+
],
|
212
|
+
)
|
213
|
+
)
|
214
|
+
update_flashinfer_indices(
|
169
215
|
ForwardMode.DECODE,
|
170
216
|
self.model_runner,
|
171
217
|
req_pool_indices,
|
@@ -176,19 +222,19 @@ class CudaGraphRunner:
|
|
176
222
|
|
177
223
|
# Run and capture
|
178
224
|
def run_once():
|
179
|
-
input_metadata = InputMetadata
|
180
|
-
self.model_runner,
|
225
|
+
input_metadata = InputMetadata(
|
181
226
|
forward_mode=ForwardMode.DECODE,
|
227
|
+
batch_size=bs,
|
182
228
|
req_pool_indices=req_pool_indices,
|
183
229
|
seq_lens=seq_lens,
|
184
|
-
|
185
|
-
|
230
|
+
req_to_token_pool=self.model_runner.req_to_token_pool,
|
231
|
+
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
186
232
|
out_cache_loc=out_cache_loc,
|
187
233
|
return_logprob=False,
|
188
234
|
top_logprobs_nums=0,
|
189
|
-
|
235
|
+
positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64),
|
236
|
+
flashinfer_decode_wrapper=flashinfer_decode_wrapper,
|
190
237
|
)
|
191
|
-
input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
|
192
238
|
|
193
239
|
return forward(input_ids, input_metadata.positions, input_metadata)
|
194
240
|
|
@@ -210,8 +256,8 @@ class CudaGraphRunner:
|
|
210
256
|
index = bisect.bisect_left(self.batch_size_list, raw_bs)
|
211
257
|
bs = self.batch_size_list[index]
|
212
258
|
if bs != raw_bs:
|
213
|
-
self.seq_lens.
|
214
|
-
self.position_ids_offsets.
|
259
|
+
self.seq_lens.zero_()
|
260
|
+
self.position_ids_offsets.fill_(1)
|
215
261
|
self.out_cache_loc.zero_()
|
216
262
|
|
217
263
|
# Common inputs
|
@@ -222,7 +268,7 @@ class CudaGraphRunner:
|
|
222
268
|
self.out_cache_loc[:raw_bs] = batch.out_cache_loc
|
223
269
|
|
224
270
|
# FlashInfer inputs
|
225
|
-
|
271
|
+
update_flashinfer_indices(
|
226
272
|
ForwardMode.DECODE,
|
227
273
|
self.model_runner,
|
228
274
|
self.req_pool_indices[:bs],
|