sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
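One structural change worth calling out from the file list: the OpenAI-compatible layer moved from sglang/srt/openai_api/ into sglang/srt/entrypoints/openai/, with adapter.py removed in favor of the new serving_* modules. For downstream code that imports these modules directly, the update would look roughly like the sketch below, assuming the import path mirrors the file path; ChatCompletionRequest is used here only as an illustrative symbol from the protocol module:

    # Hypothetical import update for the openai_api -> entrypoints/openai move (0.4.7 -> 0.4.8).
    # The symbol name is an assumption for illustration, not taken from this diff.
    # from sglang.srt.openai_api.protocol import ChatCompletionRequest          # 0.4.7 layout
    from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest    # 0.4.8 layout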
sglang/srt/disaggregation/utils.py
CHANGED
@@ -6,6 +6,7 @@ import random
 import threading
 import warnings
 from collections import deque
+from contextlib import nullcontext
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional
 
@@ -14,15 +15,15 @@ import requests
 import torch
 import torch.distributed as dist
 
-from sglang.srt.utils import get_ip
+from sglang.srt.utils import get_ip
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
 
-
-
-
-
+#########################
+# Constants & Enums
+#########################
+FAKE_BOOTSTRAP_HOST = "2.2.2.2"
 
 
 class DisaggregationMode(Enum):
@@ -31,6 +32,14 @@ class DisaggregationMode(Enum):
     DECODE = "decode"
 
 
+#########################
+# Synchronization
+#########################
+
+# env var for testing failure, convert to float explicitly
+FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
+
+
 def poll_and_all_reduce(pollers, gloo_group):
     # at a certain prob, the poll is failed to simulate failure
     if FAILURE_PROB > 0:
@@ -47,6 +56,11 @@ def poll_and_all_reduce(pollers, gloo_group):
     return tensor_to_reduce.tolist()
 
 
+#########################
+# Metadata Buffers
+#########################
+
+
 class ReqToMetadataIdxAllocator:
     """A memory pool that maps a request to its first output token location."""
 
@@ -70,6 +84,118 @@ class ReqToMetadataIdxAllocator:
         self.free_slots.append(free_index)
 
 
+class MetadataBuffers:
+    def __init__(
+        self,
+        size: int,
+        hidden_size: int,
+        dtype: torch.dtype,
+        max_top_logprobs_num: int = 128,
+        custom_mem_pool: torch.cuda.MemPool = None,
+    ):
+        self.custom_mem_pool = custom_mem_pool
+        device = "cuda" if self.custom_mem_pool else "cpu"
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            # TODO: abort top_logprobs_num > 128 in PD
+
+            # We transfer the metadata of first output token to decode
+            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+
+            self.output_hidden_states = torch.zeros(
+                (size, hidden_size), dtype=dtype, device=device
+            )
+            self.output_token_logprobs_val = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_token_logprobs_idx = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
+            self.output_top_logprobs_val = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.float32, device=device
+            )
+            self.output_top_logprobs_idx = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.int32, device=device
+            )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_hidden_states.data_ptr(),  # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_hidden_states.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_hidden_states[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_hidden_states[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not none or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not none or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+
+            if req.output_top_logprobs_val:  # not none or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not none or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
+
+
+#########################
+# Transfer Backend
+#########################
+
+
 class TransferBackend(Enum):
     MOONCAKE = "mooncake"
     NIXL = "nixl"
@@ -77,6 +203,7 @@ class TransferBackend(Enum):
 
 
 class KVClassType(Enum):
+    KVARGS = "kvargs"
     MANAGER = "manager"
     SENDER = "sender"
     RECEIVER = "receiver"
@@ -87,6 +214,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
     from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
 
     if transfer_backend == TransferBackend.MOONCAKE:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.mooncake import (
            MooncakeKVBootstrapServer,
            MooncakeKVManager,
@@ -95,13 +223,15 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         )
 
         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.MANAGER: MooncakeKVManager,
             KVClassType.SENDER: MooncakeKVSender,
             KVClassType.RECEIVER: (MooncakeKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
         }
         return class_mapping.get(class_type)
-
+    elif transfer_backend == TransferBackend.NIXL:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.nixl import (
            NixlKVBootstrapServer,
            NixlKVManager,
@@ -110,16 +240,19 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
         )
 
         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.MANAGER: NixlKVManager,
             KVClassType.SENDER: NixlKVSender,
             KVClassType.RECEIVER: (NixlKVReceiver),
             KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
         }
         return class_mapping.get(class_type)
-
+    elif transfer_backend == TransferBackend.FAKE:
+        from sglang.srt.disaggregation.base import KVArgs
         from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
 
         class_mapping = {
+            KVClassType.KVARGS: KVArgs,
             KVClassType.SENDER: FakeKVSender,
             KVClassType.RECEIVER: (FakeKVReceiver),
         }
@@ -128,6 +261,11 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
     raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
 
 
+#########################
+# KV Pages
+#########################
+
+
 def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
     # 1. The page is guaranteed to be full except the last page.
     # 2. page index = kv_index // page_size
@@ -143,6 +281,11 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
     return (num_kv_indices + page_size - 1) // page_size
 
 
+#########################
+# PDLB Registry
+#########################
+
+
 @dataclasses.dataclass
 class PDRegistryRequest:
     """A request to register a machine itself to the LB."""
@@ -181,6 +324,11 @@ def register_disaggregation_server(
     )
 
 
+#########################
+# Misc
+#########################
+
+
 def is_mla_backend(target_kv_pool) -> bool:
     from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 
@@ -200,119 +348,3 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
     req.input_top_logprobs_idx = []
     req.input_token_ids_logprobs_val = []
     req.input_token_ids_logprobs_idx = []
-
-
-class MetadataBuffers:
-    def __init__(self, size: int, max_top_logprobs_num: int = 128):
-        # TODO: abort top_logprobs_num > 128 in PD
-
-        # We transfer the metadata of first output token to decode
-        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
-        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
-        self.output_token_logprobs_val = torch.zeros(
-            (size, 16), dtype=torch.float32, device="cpu"
-        )
-        self.output_token_logprobs_idx = torch.zeros(
-            (size, 16), dtype=torch.int32, device="cpu"
-        )
-        self.output_top_logprobs_val = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
-        )
-        self.output_top_logprobs_idx = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
-        )
-
-    def get_buf_infos(self):
-        ptrs = [
-            self.output_ids.data_ptr(),
-            self.output_token_logprobs_val.data_ptr(),
-            self.output_token_logprobs_idx.data_ptr(),
-            self.output_top_logprobs_val.data_ptr(),
-            self.output_top_logprobs_idx.data_ptr(),
-        ]
-        data_lens = [
-            self.output_ids.nbytes,
-            self.output_token_logprobs_val.nbytes,
-            self.output_token_logprobs_idx.nbytes,
-            self.output_top_logprobs_val.nbytes,
-            self.output_top_logprobs_idx.nbytes,
-        ]
-        item_lens = [
-            self.output_ids[0].nbytes,
-            self.output_token_logprobs_val[0].nbytes,
-            self.output_token_logprobs_idx[0].nbytes,
-            self.output_top_logprobs_val[0].nbytes,
-            self.output_top_logprobs_idx[0].nbytes,
-        ]
-        return ptrs, data_lens, item_lens
-
-    def get_buf(self, idx: int):
-        return (
-            self.output_ids[idx],
-            self.output_token_logprobs_val[idx],
-            self.output_token_logprobs_idx[idx],
-            self.output_top_logprobs_val[idx],
-            self.output_top_logprobs_idx[idx],
-        )
-
-    def set_buf(self, req: Req):
-
-        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
-        if req.return_logprob:
-            if req.output_token_logprobs_val:  # not none or empty list
-                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
-                    req.output_token_logprobs_val[0]
-                )
-            if req.output_token_logprobs_idx:  # not none or empty list
-                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
-                    req.output_token_logprobs_idx[0]
-                )
-
-            if req.output_top_logprobs_val:  # not none or empty list
-                self.output_top_logprobs_val[req.metadata_buffer_index][
-                    : len(req.output_top_logprobs_val[0])
-                ] = torch.tensor(
-                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
-                )
-            if req.output_top_logprobs_idx:  # not none or empty list
-                self.output_top_logprobs_idx[req.metadata_buffer_index][
-                    : len(req.output_top_logprobs_idx[0])
-                ] = torch.tensor(
-                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
-                )
-
-
-class FastQueue:
-    def __init__(self):
-        self._buf = deque()
-        self._cond = threading.Condition()
-
-    def put(self, item):
-        with self._cond:
-            self._buf.append(item)
-            # wake up a thread of wait()
-            self._cond.notify()
-
-    def get(self):
-        with self._cond:
-            # if queue is empty ,block until is notified()
-            while not self._buf:
-                self._cond.wait()
-            return self._buf.popleft()
-
-
-def group_concurrent_contiguous(
-    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
-) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    """Vectorised NumPy implementation."""
-    if src_indices.size == 0:
-        return [], []
-
-    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
-    src_groups = np.split(src_indices, brk)
-    dst_groups = np.split(dst_indices, brk)
-
-    src_groups = [g.tolist() for g in src_groups]
-    dst_groups = [g.tolist() for g in dst_groups]
-
-    return src_groups, dst_groups
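For orientation, here is a minimal sketch of how the reworked MetadataBuffers constructor might be called after this change; the concrete values (slot count, hidden size, dtype) are illustrative assumptions, not taken from the diff:

    import torch

    from sglang.srt.disaggregation.utils import MetadataBuffers

    # Illustrative values; the real ones come from the scheduler and model config.
    buffers = MetadataBuffers(
        size=256,               # number of metadata slots (in-flight requests)
        hidden_size=4096,       # hidden size of the served model
        dtype=torch.bfloat16,   # dtype of the transferred hidden states
        custom_mem_pool=None,   # pass a torch.cuda.MemPool to place the buffers on GPU instead of CPU
    )
    # Pointers and lengths, e.g. for registering the buffers with a KV transfer backend.
    ptrs, data_lens, item_lens = buffers.get_buf_infos()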
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -523,17 +523,25 @@ class GroupCoordinator:
         self,
         input_: torch.Tensor,
         dim: int = -1,
-
+        output_tensor_list: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
-
+            if output_tensor_list is not None:
+                logger.warning(
+                    "Performing in-place all-gather with a group size of 1. "
+                    "This may be unnecessary; consider bypassing it for better efficiency."
+                )
+                output_tensor_list[0].copy_(input_)
+                return None
+            else:
+                return input_
 
-        if
+        if output_tensor_list is not None:
             # TODO(ch-wan): support other backends
             return torch.distributed.all_gather(
-
+                output_tensor_list, input_, group=self.device_group
             )
 
         assert (
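A rough sketch of what the new output_tensor_list argument enables, assuming an already-initialized GroupCoordinator instance named group and a local tensor x (both placeholders, not taken from this diff):

    import torch

    # Original call style: returns a single tensor gathered across the group.
    gathered = group.all_gather(x, dim=0)

    # New call style: gather into a caller-provided list of per-rank tensors.
    shards = [torch.empty_like(x) for _ in range(group.world_size)]
    group.all_gather(x, output_tensor_list=shards)  # delegates to torch.distributed.all_gather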
sglang/srt/entrypoints/engine.py
CHANGED
@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 import torch
 import uvloop
 
-from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
 from sglang.srt.entrypoints.EngineBase import EngineBase
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import (
-    guess_chat_template_name_from_model_path,
-    load_chat_template_for_openai_api,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -123,12 +119,13 @@ class Engine(EngineBase):
         logger.info(f"{server_args=}")
 
         # Launch subprocesses
-        tokenizer_manager, scheduler_info = _launch_subprocesses(
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args,
             port_args=port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
+        self.template_manager = template_manager
         self.scheduler_info = scheduler_info
 
         context = zmq.Context(2)
@@ -175,7 +172,7 @@ class Engine(EngineBase):
         """
         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -258,7 +255,7 @@ class Engine(EngineBase):
 
         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -327,6 +324,20 @@ class Engine(EngineBase):
         generator = self.tokenizer_manager.generate_request(obj, None)
         return await generator.__anext__()
 
+    def rerank(
+        self,
+        prompt: Union[List[List[str]]],
+    ) -> Dict:
+        """
+        The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
+        Please refer to `EmbeddingReqInput` for the documentation.
+        """
+        obj = EmbeddingReqInput(text=prompt, is_cross_encoder_request=True)
+        loop = asyncio.get_event_loop()
+        generator = self.tokenizer_manager.generate_request(obj, None)
+        ret = loop.run_until_complete(generator.__anext__())
+        return ret
+
     def shutdown(self):
         """Shutdown the engine"""
         kill_process_tree(os.getpid(), include_parent=False)
@@ -465,17 +476,15 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )
 
-    def release_memory_occupation(self):
-
-        obj = ReleaseMemoryOccupationReqInput()
+    def release_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.release_memory_occupation(obj, None)
         )
 
-    def resume_memory_occupation(self):
-
-        obj = ResumeMemoryOccupationReqInput()
+    def resume_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ResumeMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.resume_memory_occupation(obj, None)
@@ -605,7 +614,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.
+            "0.1.9",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
@@ -635,7 +644,7 @@ def _set_envs_and_config(server_args: ServerArgs):
 
 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
-) -> Tuple[TokenizerManager, Dict]:
+) -> Tuple[TokenizerManager, TemplateManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -656,11 +665,9 @@ def _launch_subprocesses(
 
     scheduler_procs = []
     if server_args.dp_size == 1:
-        # Launch tensor parallel scheduler processes
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
         scheduler_pipe_readers = []
 
         nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -696,6 +703,7 @@ def _launch_subprocesses(
                     writer,
                 ),
             )
+
            with memory_saver_adapter.configure_subprocess():
                proc.start()
            scheduler_procs.append(proc)
@@ -721,7 +729,7 @@ def _launch_subprocesses(
 
         if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
             # When using `Engine` as a Python API, we don't want to block here.
-            return None, None
+            return None, None, None
 
         launch_dummy_health_check_server(server_args.host, server_args.port)
 
@@ -730,7 +738,7 @@ def _launch_subprocesses(
             logger.error(
                 f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
             )
-            return None, None
+            return None, None, None
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -744,15 +752,15 @@ def _launch_subprocesses(
 
     # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(
-            tokenizer_manager, server_args.chat_template, server_args.model_path
-        )
-    else:
-        guess_chat_template_name_from_model_path(server_args.model_path)
 
-
-
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )
 
     # Wait for the model to finish loading
     scheduler_infos = []
@@ -776,4 +784,4 @@ def _launch_subprocesses(
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
-    return tokenizer_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info
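Taken together, the engine-level changes surface as small additions to the offline API. A hedged usage sketch, where the model path, query/document pair, and tag values are placeholders rather than values taken from this diff:

    import sglang as sgl

    engine = sgl.Engine(model_path="/path/to/model")  # placeholder path

    # New in 0.4.8: cross-encoder reranking through the offline engine.
    scores = engine.rerank([["what is sglang?", "sglang is a fast serving framework"]])

    # release/resume now accept an optional `tags` list to scope what is released or resumed.
    engine.release_memory_occupation(tags=None)
    engine.resume_memory_occupation(tags=None)

    engine.shutdown()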