sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/utils.py
CHANGED
@@ -14,15 +14,15 @@ import requests
 import torch
 import torch.distributed as dist

-from sglang.srt.utils import get_ip
+from sglang.srt.utils import get_ip

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req

-
-
-
-
+#########################
+# Constants & Enums
+#########################
+FAKE_BOOTSTRAP_HOST = "2.2.2.2"


 class DisaggregationMode(Enum):
@@ -31,6 +31,14 @@ class DisaggregationMode(Enum):
     DECODE = "decode"


+#########################
+# Synchronization
+#########################
+
+# env var for testing failure, convert to float explicitly
+FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
+
+
 def poll_and_all_reduce(pollers, gloo_group):
     # at a certain prob, the poll is failed to simulate failure
     if FAILURE_PROB > 0:
@@ -47,6 +55,11 @@ def poll_and_all_reduce(pollers, gloo_group):
     return tensor_to_reduce.tolist()


+#########################
+# Metadata Buffers
+#########################
+
+
 class ReqToMetadataIdxAllocator:
     """A memory pool that maps a request to its first output token location."""

@@ -70,6 +83,91 @@ class ReqToMetadataIdxAllocator:
         self.free_slots.append(free_index)


+class MetadataBuffers:
+    def __init__(self, size: int, max_top_logprobs_num: int = 128):
+        # TODO: abort top_logprobs_num > 128 in PD
+
+        # We transfer the metadata of first output token to decode
+        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
+        self.output_token_logprobs_val = torch.zeros(
+            (size, 16), dtype=torch.float32, device="cpu"
+        )
+        self.output_token_logprobs_idx = torch.zeros(
+            (size, 16), dtype=torch.int32, device="cpu"
+        )
+        self.output_top_logprobs_val = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
+        )
+        self.output_top_logprobs_idx = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
+        )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not none or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not none or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+
+            if req.output_top_logprobs_val:  # not none or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not none or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
+
+
+#########################
+# Transfer Backend
+#########################
+
+
 class TransferBackend(Enum):
     MOONCAKE = "mooncake"
     NIXL = "nixl"
@@ -77,6 +175,7 @@ class TransferBackend(Enum):


 class KVClassType(Enum):
+    KVARGS = "kvargs"
     MANAGER = "manager"
     SENDER = "sender"
     RECEIVER = "receiver"
@@ -87,6 +186,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
     from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

     if transfer_backend == TransferBackend.MOONCAKE:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.mooncake import (
             MooncakeKVBootstrapServer,
             MooncakeKVManager,
@@ -95,13 +195,15 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         )

         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.MANAGER: MooncakeKVManager,
             KVClassType.SENDER: MooncakeKVSender,
             KVClassType.RECEIVER: (MooncakeKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
-
+    elif transfer_backend == TransferBackend.NIXL:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.nixl import (
             NixlKVBootstrapServer,
             NixlKVManager,
@@ -110,16 +212,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         )

         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.MANAGER: NixlKVManager,
             KVClassType.SENDER: NixlKVSender,
             KVClassType.RECEIVER: (NixlKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
         }
         return class_mapping.get(class_type)
-
+    elif transfer_backend == TransferBackend.FAKE:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.SENDER: FakeKVSender,
             KVClassType.RECEIVER: (FakeKVReceiver),
         }
@@ -128,6 +233,11 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")


+#########################
+# KV Pages
+#########################
+
+
 def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
     # 1. The page is guaranteed to be full except the last page.
     # 2. page index = kv_index // page_size
@@ -143,6 +253,11 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
     return (num_kv_indices + page_size - 1) // page_size


+#########################
+# PDLB Registry
+#########################
+
+
 @dataclasses.dataclass
 class PDRegistryRequest:
     """A request to register a machine itself to the LB."""
@@ -181,6 +296,11 @@ def register_disaggregation_server(
     )


+#########################
+# Misc
+#########################
+
+
 def is_mla_backend(target_kv_pool) -> bool:
     from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool

@@ -200,119 +320,3 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
     req.input_top_logprobs_idx = []
     req.input_token_ids_logprobs_val = []
     req.input_token_ids_logprobs_idx = []
-
-
-class MetadataBuffers:
-    def __init__(self, size: int, max_top_logprobs_num: int = 128):
-        # TODO: abort top_logprobs_num > 128 in PD
-
-        # We transfer the metadata of first output token to decode
-        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
-        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
-        self.output_token_logprobs_val = torch.zeros(
-            (size, 16), dtype=torch.float32, device="cpu"
-        )
-        self.output_token_logprobs_idx = torch.zeros(
-            (size, 16), dtype=torch.int32, device="cpu"
-        )
-        self.output_top_logprobs_val = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
-        )
-        self.output_top_logprobs_idx = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
-        )
-
-    def get_buf_infos(self):
-        ptrs = [
-            self.output_ids.data_ptr(),
-            self.output_token_logprobs_val.data_ptr(),
-            self.output_token_logprobs_idx.data_ptr(),
-            self.output_top_logprobs_val.data_ptr(),
-            self.output_top_logprobs_idx.data_ptr(),
-        ]
-        data_lens = [
-            self.output_ids.nbytes,
-            self.output_token_logprobs_val.nbytes,
-            self.output_token_logprobs_idx.nbytes,
-            self.output_top_logprobs_val.nbytes,
-            self.output_top_logprobs_idx.nbytes,
-        ]
-        item_lens = [
-            self.output_ids[0].nbytes,
-            self.output_token_logprobs_val[0].nbytes,
-            self.output_token_logprobs_idx[0].nbytes,
-            self.output_top_logprobs_val[0].nbytes,
-            self.output_top_logprobs_idx[0].nbytes,
-        ]
-        return ptrs, data_lens, item_lens
-
-    def get_buf(self, idx: int):
-        return (
-            self.output_ids[idx],
-            self.output_token_logprobs_val[idx],
-            self.output_token_logprobs_idx[idx],
-            self.output_top_logprobs_val[idx],
-            self.output_top_logprobs_idx[idx],
-        )
-
-    def set_buf(self, req: Req):
-
-        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
-        if req.return_logprob:
-            if req.output_token_logprobs_val:  # not none or empty list
-                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
-                    req.output_token_logprobs_val[0]
-                )
-            if req.output_token_logprobs_idx:  # not none or empty list
-                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
-                    req.output_token_logprobs_idx[0]
-                )
-
-            if req.output_top_logprobs_val:  # not none or empty list
-                self.output_top_logprobs_val[req.metadata_buffer_index][
-                    : len(req.output_top_logprobs_val[0])
-                ] = torch.tensor(
-                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
-                )
-            if req.output_top_logprobs_idx:  # not none or empty list
-                self.output_top_logprobs_idx[req.metadata_buffer_index][
-                    : len(req.output_top_logprobs_idx[0])
-                ] = torch.tensor(
-                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
-                )
-
-
-class FastQueue:
-    def __init__(self):
-        self._buf = deque()
-        self._cond = threading.Condition()
-
-    def put(self, item):
-        with self._cond:
-            self._buf.append(item)
-            # wake up a thread of wait()
-            self._cond.notify()
-
-    def get(self):
-        with self._cond:
-            # if queue is empty ,block until is notified()
-            while not self._buf:
-                self._cond.wait()
-            return self._buf.popleft()
-
-
-def group_concurrent_contiguous(
-    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
-) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    """Vectorised NumPy implementation."""
-    if src_indices.size == 0:
-        return [], []
-
-    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
-    src_groups = np.split(src_indices, brk)
-    dst_groups = np.split(dst_indices, brk)
-
-    src_groups = [g.tolist() for g in src_groups]
-    dst_groups = [g.tolist() for g in dst_groups]
-
-    return src_groups, dst_groups
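For orientation, a minimal sketch of how the extended registry in get_kv_class might be used now that KVClassType.KVARGS is part of the mapping; only get_kv_class, TransferBackend, and KVClassType come from the diff above, and the variable names are illustrative.

# Hypothetical usage sketch; names other than the imported symbols are illustrative.
from sglang.srt.disaggregation.utils import KVClassType, TransferBackend, get_kv_class

kv_args_cls = get_kv_class(TransferBackend.MOONCAKE, KVClassType.KVARGS)    # expected: KVArgs
sender_cls = get_kv_class(TransferBackend.MOONCAKE, KVClassType.SENDER)     # expected: MooncakeKVSender
receiver_cls = get_kv_class(TransferBackend.MOONCAKE, KVClassType.RECEIVER) # expected: MooncakeKVReceiver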
sglang/srt/entrypoints/engine.py
CHANGED
@@ -327,6 +327,20 @@ class Engine(EngineBase):
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()

+    def rerank(
+        self,
+        prompt: Union[List[List[str]]],
+    ) -> Dict:
+        """
+        The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
+        Please refer to `EmbeddingReqInput` for the documentation.
+        """
+        obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
+        loop = asyncio.get_event_loop()
+        generator = self.tokenizer_manager.generate_request(obj, None)
+        ret = loop.run_until_complete(generator.__anext__())
+        return ret
+
     def shutdown(self):
         """Shutdown the engine"""
         kill_process_tree(os.getpid(), include_parent=False)
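A minimal usage sketch of the new Engine.rerank method shown above; the model path and the query/document pairing are illustrative assumptions, not taken from this diff.

# Hypothetical usage sketch; model path and input pairing are illustrative assumptions.
import sglang as sgl

engine = sgl.Engine(model_path="BAAI/bge-reranker-v2-m3")  # a cross-encoder style model
# The List[List[str]] annotation above suggests each inner list pairs a query with a candidate document.
scores = engine.rerank([["what is the capital of France?", "Paris is the capital of France."]])
print(scores)
engine.shutdown()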
@@ -605,7 +619,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.
+            "0.1.9",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -43,7 +43,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse

 from sglang.srt.disaggregation.utils import (
-
+    FAKE_BOOTSTRAP_HOST,
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
@@ -67,6 +67,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
+    V1RerankReqInput,
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -79,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
     v1_delete_file,
     v1_embeddings,
     v1_files_create,
+    v1_rerank,
     v1_retrieve_batch,
     v1_retrieve_file,
     v1_retrieve_file_content,
@@ -328,6 +330,15 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


+@app.api_route("/v1/rerank", methods=["POST", "PUT"])
+async def v1_rerank_request(obj: V1RerankReqInput, raw_request: Request):
+    try:
+        ret = await v1_rerank(_global_state.tokenizer_manager, obj, raw_request)
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
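A hedged example of calling the new /v1/rerank route; the request body fields and the port are assumptions based on typical rerank APIs, since the definition of V1RerankReqInput is not shown in this diff.

# Hypothetical request; field names and port are illustrative, not confirmed by this diff.
import requests

resp = requests.post(
    "http://localhost:30000/v1/rerank",
    json={
        "query": "what is the capital of France?",
        "documents": ["Paris is the capital of France.", "Berlin is in Germany."],
    },
)
print(resp.json())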
@@ -878,7 +889,7 @@ def _wait_and_warmup(
             "max_new_tokens": 8,
             "ignore_eos": True,
         },
-        "bootstrap_host": [
+        "bootstrap_host": [FAKE_BOOTSTRAP_HOST] * server_args.dp_size,
         # This is a hack to ensure fake transfer is enabled during prefill warmup
         # ensure each dp rank has a unique bootstrap_room during prefill warmup
         "bootstrap_room": [
sglang/srt/eplb_simulator/__init__.py
ADDED
@@ -0,0 +1 @@
+from . import reader

sglang/srt/eplb_simulator/reader.py
ADDED
@@ -0,0 +1,51 @@
+from collections import defaultdict
+from pathlib import Path
+
+import torch
+from tqdm import tqdm
+
+from sglang.srt.managers.expert_distribution import (
+    _convert_global_physical_count_to_logical_count,
+)
+
+convert_global_physical_count_to_logical_count = (
+    _convert_global_physical_count_to_logical_count
+)
+
+
+def read_mode_per_pass(dir_data: Path):
+    """Read data from ExpertDistributionRecorder when recorded with mode `per_pass`"""
+
+    # gpc := global_physical_count
+    gpc_of_forward_pass_and_rank = defaultdict(lambda: defaultdict())
+    for path in tqdm(list(dir_data.glob("*.pt"))):
+        data_pack = torch.load(path, weights_only=True)
+        last_physical_to_logical_map = data_pack["last_physical_to_logical_map"]
+        for record in data_pack["records"]:
+            forward_pass_id = record["forward_pass_id"]
+            rank = record["rank"]
+            assert (
+                gpc_of_forward_pass_and_rank[forward_pass_id].get(rank) is None
+            ), f"Duplicated {forward_pass_id=} {rank=}"
+            gpc_of_forward_pass_and_rank[forward_pass_id][rank] = record[
+                "global_physical_count"
+            ]
+
+    forward_pass_ids = sorted(gpc_of_forward_pass_and_rank.keys())
+    print(f"Make {forward_pass_ids=} into array")
+
+    items = []
+    for forward_pass_id, gpc_of_rank in sorted(gpc_of_forward_pass_and_rank.items()):
+        gpc_of_rank_tensor = torch.stack(
+            [gpc for rank, gpc in sorted(gpc_of_rank.items())]
+        ).sum(dim=0)
+        items.append(gpc_of_rank_tensor)
+
+    gpc_of_forward_pass = torch.stack(items)
+    print(f"{gpc_of_forward_pass.shape=}")
+
+    return dict(
+        global_physical_count_of_forward_pass=gpc_of_forward_pass,
+        last_physical_to_logical_map=last_physical_to_logical_map,
+        forward_pass_ids=forward_pass_ids,
+    )
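A minimal sketch of driving the new reader; the dump directory is illustrative and assumed to contain .pt files written by ExpertDistributionRecorder in per_pass mode.

# Hypothetical usage sketch; the directory path is illustrative.
from pathlib import Path

from sglang.srt.eplb_simulator import reader

data = reader.read_mode_per_pass(Path("/tmp/expert_distribution_dump"))
print(data["forward_pass_ids"])
print(data["global_physical_count_of_forward_pass"].shape)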
sglang/srt/layers/activation.py
CHANGED
@@ -20,6 +20,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers import PretrainedConfig

 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -29,6 +30,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import is_cuda, set_weight_attrs
+from sglang.utils import resolve_obj_by_qualname

 _is_cuda = is_cuda()

@@ -165,6 +167,23 @@ def get_act_fn(
     return act_fn


+def get_cross_encoder_activation_function(config: PretrainedConfig):
+    if (
+        hasattr(config, "sbert_ce_default_activation_function")
+        and config.sbert_ce_default_activation_function is not None
+    ):
+
+        function_name = config.sbert_ce_default_activation_function
+        assert function_name.startswith("torch.nn.modules."), (
+            "Loading of activation functions is restricted to "
+            "torch.nn.modules for security reasons"
+        )
+        return resolve_obj_by_qualname(function_name)()
+    else:
+        # adapt bge-reranker
+        return nn.Identity()
+
+
 if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
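A short sketch of how the new get_cross_encoder_activation_function helper behaves; the config values below are illustrative, not taken from this diff.

# Hypothetical sketch; config values are illustrative.
import torch.nn as nn
from transformers import PretrainedConfig

from sglang.srt.layers.activation import get_cross_encoder_activation_function

cfg = PretrainedConfig()
cfg.sbert_ce_default_activation_function = "torch.nn.modules.activation.Sigmoid"
print(get_cross_encoder_activation_function(cfg))        # prints Sigmoid()

plain_cfg = PretrainedConfig()  # attribute absent, e.g. bge-reranker style configs
print(get_cross_encoder_activation_function(plain_cfg))  # prints Identity()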
sglang/srt/layers/attention/aiter_backend.py
CHANGED
@@ -717,6 +717,11 @@ class AiterIndicesUpdaterPrefill:
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.update = self.update_single_wrapper

+        # get the last index of the pool
+        self.pool_size = (
+            model_runner.token_to_kv_pool.size + model_runner.token_to_kv_pool.page_size
+        ) - 1
+
         self.kv_indices = None
         self.max_q_len = 0
         self.max_kv_len = 0
@@ -754,8 +759,16 @@
             # Normal extend
             kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
-
-
+
+            # (TODO: Kk) WA - CI test_moe_eval_accuracy_large.py
+            # mha_batch_prefill reads 128 data to do computatoin
+            # if real data is not long enough then original padding value 0 is used
+            # but the 0 location will be made nan (noqa) in cuda graph capture mode
+            # this will cause the output tensor value becomes nan
+            # WA is to assure that last index of pool not changed
+            kv_indices = torch.full(
+                (paged_kernel_lens_sum + 128,),
+                self.pool_size,
                 dtype=torch.int32,
                 device=req_pool_indices.device,
             )
sglang/srt/layers/attention/cutlass_mla_backend.py
CHANGED
@@ -11,8 +11,6 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 import triton

-from sglang.global_config import global_config
-from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -22,7 +20,6 @@ from sglang.srt.utils import is_cuda
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpecInfo

 _is_cuda = is_cuda()
@@ -108,7 +105,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                 PAGE_SIZE,
             )
             workspace_size = cutlass_mla_get_workspace_size(
-                max_seqlen_pad * PAGE_SIZE, bs
+                max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
             )
             workspace = torch.empty(
                 workspace_size, device="cuda", dtype=torch.uint8
@@ -138,7 +135,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             cuda_graph_kv_indices = block_kv_indices

         workspace_size = cutlass_mla_get_workspace_size(
-            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs, num_kv_splits=1
         )
         self.cuda_graph_mla_workspace = torch.empty(
             workspace_size, device="cuda", dtype=torch.uint8
@@ -233,29 +230,55 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         cache_loc = forward_batch.out_cache_loc

         if k is not None:
             assert v is not None
             if save_kv_cache:
-
-
-
-
-
-
-
-
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+
+        # Reshape inputs
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]

-
+        q_nope = q_nope.to(self.q_data_type)
+        q_rope = q_rope.to(self.q_data_type)
+
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)

         o = cutlass_mla_decode(
-
+            q_nope=q_nope,
+            q_pe=q_rope,
             kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
             seq_lens=forward_batch.seq_lens.to(torch.int32),
             page_table=self.forward_metadata.block_kv_indices,
             workspace=self.forward_metadata.workspace,
+            sm_scale=layer.scaling,
+            num_kv_splits=1,
         )

         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)