sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +3 -2
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/openai.py +5 -1
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/disaggregation/decode.py +43 -0
- sglang/srt/disaggregation/mini_lb.py +69 -8
- sglang/srt/disaggregation/mooncake/conn.py +1 -1
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +100 -16
- sglang/srt/disaggregation/utils.py +17 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +781 -150
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +19 -4
- sglang/srt/layers/moe/ep_moe/layer.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8_kernel.py +7 -38
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/w8a8_int8.py +3 -3
- sglang/srt/layers/rotary_embedding.py +6 -6
- sglang/srt/layers/sampler.py +2 -2
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/io_struct.py +14 -3
- sglang/srt/managers/schedule_batch.py +13 -0
- sglang/srt/managers/scheduler.py +16 -6
- sglang/srt/managers/tokenizer_manager.py +115 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +40 -32
- sglang/srt/mem_cache/memory_pool.py +31 -13
- sglang/srt/model_executor/cuda_graph_runner.py +13 -8
- sglang/srt/model_executor/model_runner.py +19 -4
- sglang/srt/models/deepseek_v2.py +9 -6
- sglang/srt/models/minicpm3.py +2 -2
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/server_args.py +52 -40
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +2 -7
- sglang/srt/utils.py +46 -5
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
```diff
@@ -690,7 +690,6 @@ def sample_random_requests(
     dataset_path: str,
     random_sample: bool = True,
 ) -> List[Tuple[str, int, int]]:
-
     input_lens = np.random.randint(
         max(int(input_len * range_ratio), 1),
         input_len + 1,
@@ -1025,7 +1024,9 @@ async def benchmark(
     warmup_outputs = await asyncio.gather(*warmup_tasks)
 
     # Check if at least one warmup request succeeded
-    if not any(output.success for output in warmup_outputs):
+    if args.warmup_requests > 0 and not any(
+        output.success for output in warmup_outputs
+    ):
         raise ValueError(
             "Warmup failed - Please make sure benchmark arguments "
             f"are correctly specified. Error: {warmup_outputs[0].error}"
```
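Presumably the point of the new `args.warmup_requests > 0` guard: `any()` over an empty sequence is `False`, so with warmup disabled the old check would raise even though nothing failed. A quick illustration:

```python
# With --warmup-requests 0, no warmup tasks are created and the gathered list is empty.
warmup_outputs = []
print(any(output.success for output in warmup_outputs))  # False -> the old check raised spuriously
```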
sglang/compile_deep_gemm.py
ADDED
```diff
@@ -0,0 +1,136 @@
+"""
+Compile DeepGEMM Kernels for a model with specify server arguments
+
+This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
+It accepts server arguments (the same as launch_server.py).
+
+Usage:
+python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
+
+"""
+
+import argparse
+import dataclasses
+import multiprocessing
+import os
+import time
+
+import requests
+
+from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_process_tree
+from sglang.srt.warmup import warmup
+
+multiprocessing.set_start_method("spawn", force=True)
+
+# Reduce warning
+os.environ["SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE"] = "1"
+
+
+@dataclasses.dataclass
+class CompileArgs:
+    timeout: int = 3600
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # use the default value's type to cast the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+@warmup("compile-deep-gemm")
+async def warm_up_compile(tokenizer_manager: TokenizerManager):
+    print("\nGenerate warm up request for compiling DeepGEMM...\n")
+    generate_req_input = GenerateReqInput(
+        input_ids=[0, 1, 2, 3],
+        sampling_params={
+            "temperature": 0.0,
+            "max_new_tokens": 8,
+            "ignore_eos": True,
+        },
+    )
+    await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
+
+
+def launch_server_internal(server_args):
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_process_tree(os.getpid(), include_parent=False)
+
+
+def launch_server_process_and_send_one_request(
+    server_args: ServerArgs, compile_args: CompileArgs
+):
+    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
+    proc.start()
+    base_url = f"http://{server_args.host}:{server_args.port}"
+    timeout = compile_args.timeout
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+            }
+            response = requests.get(f"{base_url}/v1/models", headers=headers)
+            if response.status_code == 200:
+                return proc
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError(
+        "DeepGEMM Kernels compilation timeout."
+        "\n\nFeel free and please restart the command."
+    )
+
+
+def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
+    # Disbale cuda graph and torch compile to save time
+    server_args.disable_cuda_graph = True
+    server_args.enable_torch_compile = False
+    print(f"Disable CUDA Graph and Torch Compile to save time...")
+
+    # Set watchdog timeout to compile_args.timeout because compilation will take a long time
+    server_args.watchdog_timeout = compile_args.timeout
+    server_args.warmups = "compile-deep-gemm"
+
+
+def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
+    print(
+        "Begin DeepGEMM Kernels compilation...\n"
+        "It may take a long time and timeout maybe raised "
+        "while the compilation is still in progress.\n"
+        "Just feel free to restart the command "
+        "until the compilation is fully finished.\n"
+    )
+
+    proc = launch_server_process_and_send_one_request(server_args, compile_args)
+
+    kill_process_tree(proc.pid)
+
+    print("\nDeepGEMM Kernels compilation finished successfully.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    CompileArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    compile_args = CompileArgs.from_cli_args(args)
+
+    refine_server_args(server_args, compile_args)
+
+    run_compile(server_args, compile_args)
```
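For orientation, a minimal sketch of driving the new module from Python instead of the CLI. This is illustrative only: the documented entry point is the `python3 -m sglang.compile_deep_gemm` command shown in the docstring above, and the `ServerArgs` fields used here (`model_path`, `tp_size`, `trust_remote_code`) are assumed to mirror the corresponding CLI flags.

```python
# Hypothetical programmatic use of the new module (a sketch, not the documented entry point).
from sglang.compile_deep_gemm import CompileArgs, refine_server_args, run_compile
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V3",  # --model on the CLI
    tp_size=8,                             # --tp
    trust_remote_code=True,                # --trust-remote-code
)
compile_args = CompileArgs(timeout=7200)   # allow extra time for large models

# Mirrors __main__: disable CUDA graph / torch.compile, register the warmup, then compile.
refine_server_args(server_args, compile_args)
run_compile(server_args, compile_args)
```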
sglang/lang/backend/openai.py
CHANGED
```diff
@@ -161,7 +161,11 @@ class OpenAI(BaseBackend):
         prompt = s.text_
 
         kwargs = sampling_params.to_openai_kwargs()
-        if
+        if (
+            self.model_name.startswith("o1")
+            or self.model_name.startswith("o3")
+            or "o1" in self.model_name
+        ):
             kwargs.pop("max_tokens", None)
         else:
             kwargs.pop("max_completion_tokens", None)
```
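The widened check sends OpenAI reasoning models (the o1/o3 families) down the branch that keeps `max_completion_tokens`, which those models expect, while other models keep plain `max_tokens`. A small illustration with made-up kwargs:

```python
def prune_token_kwargs(model_name: str, kwargs: dict) -> dict:
    # Same branching as the patched backend, pulled out here for illustration.
    if (
        model_name.startswith("o1")
        or model_name.startswith("o3")
        or "o1" in model_name
    ):
        kwargs.pop("max_tokens", None)
    else:
        kwargs.pop("max_completion_tokens", None)
    return kwargs

print(prune_token_kwargs("o3-mini", {"max_tokens": 256, "max_completion_tokens": 256}))
# {'max_completion_tokens': 256}
print(prune_token_kwargs("gpt-4o", {"max_tokens": 256, "max_completion_tokens": 256}))
# {'max_tokens': 256}
```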
sglang/lang/backend/runtime_endpoint.py
CHANGED
```diff
@@ -324,7 +324,11 @@ class RuntimeEndpoint(BaseBackend):
 
     def _assert_success(self, res):
         if res.status_code != 200:
-            raise RuntimeError(res.json())
+            try:
+                content = res.json()
+            except json.JSONDecodeError:
+                content = res.text
+            raise RuntimeError(content)
 
 
 def compute_normalized_prompt_logprobs(input_logprobs):
```
sglang/srt/configs/model_config.py
CHANGED
```diff
@@ -73,8 +73,11 @@ class ModelConfig:
         )
 
         if enable_multimodal is None:
-            if self.hf_config.architectures == "Llama4ForConditionalGeneration":
+            if self.hf_config.architectures[0] == "Llama4ForConditionalGeneration":
                 enable_multimodal = False
+                logger.info(
+                    "Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
+                )
             else:
                 enable_multimodal = True
 
```
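The comparison fix matters because `hf_config.architectures` is a list of architecture names; comparing the list itself against a string is always `False`, so the old guard never actually disabled multimodal for Llama4. In short:

```python
# Illustration of the comparison the patch fixes (example value, not real config parsing).
architectures = ["Llama4ForConditionalGeneration"]

print(architectures == "Llama4ForConditionalGeneration")     # False: a list never equals a str
print(architectures[0] == "Llama4ForConditionalGeneration")  # True: compare the first entry
```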
sglang/srt/constrained/xgrammar_backend.py
CHANGED
```diff
@@ -158,6 +158,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
             if key_string == "$$ANY$$":
+                # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
                 ctx = self.grammar_compiler.compile_builtin_json_grammar()
             else:
                 ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
```
sglang/srt/disaggregation/decode.py
CHANGED
```diff
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
 from __future__ import annotations
 
 import logging
+from collections import deque
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
@@ -475,6 +476,48 @@ class SchedulerDisaggregationDecodeMixin:
 
         self.last_batch = batch
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_decode(self):
+        result_queue = deque()
+        self.last_batch: Optional[ScheduleBatch] = None
+        self.last_batch_is_extend = False  # last batch is modifed in-place, so we need another variable to track if it's extend
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+            last_batch_is_extend = False
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(batch.reqs, False)
+                    last_batch_is_extend = True
+                else:
+                    result = self.run_batch(batch)
+                    result_queue.append((batch.copy(), result))
+
+            # Process the results of the previous batch but skip if the last batch is extend
+            if self.last_batch and not self.last_batch_is_extend:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            self.last_batch_is_extend = last_batch_is_extend
+
     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
     ) -> Optional[Tuple[ScheduleBatch, bool]]:
```
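The new `event_loop_overlap_disagg_decode` is a one-batch pipeline: the result of batch N is processed on the next loop iteration, while batch N+1 is already running, and fake extend outputs are streamed immediately without ever entering the queue. A stripped-down sketch of that overlap pattern (`run` and `process` are stand-ins for `run_batch` / `process_batch_result`, not scheduler APIs):

```python
from collections import deque

def run(batch):              # stand-in for self.run_batch(batch)
    return f"result<{batch}>"

def process(batch, result):  # stand-in for self.process_batch_result(...)
    print(f"processed {batch}: {result}")

result_queue = deque()
last_batch = None
for batch in ["b0", "b1", "b2", None]:            # None models an idle tick with nothing to run
    if batch is not None:
        result_queue.append((batch, run(batch)))  # launch the current batch
    if last_batch is not None:
        process(*result_queue.popleft())          # finish the previous batch one step later
    last_batch = batch
```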
sglang/srt/disaggregation/mini_lb.py
CHANGED
```diff
@@ -23,8 +23,9 @@ class MiniLoadBalancer:
         return random.choice(self.prefill_servers), random.choice(self.decode_servers)
 
     async def generate(
-        self, modified_request, prefill_server, decode_server
+        self, modified_request, prefill_server, decode_server, endpoint
     ) -> ORJSONResponse:
+        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
 
         async with aiohttp.ClientSession(
             timeout=aiohttp.ClientTimeout(
@@ -32,8 +33,8 @@ class MiniLoadBalancer:
             )  # Add timeout for request reliability
         ) as session:
             tasks = [
-                session.post(f"{prefill_server}/
-                session.post(f"{decode_server}/
+                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
+                session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
             # Wait for both responses to complete. Prefill should end first.
             prefill_response, decode_response = await asyncio.gather(*tasks)
@@ -43,7 +44,11 @@ class MiniLoadBalancer:
                 status_code=decode_response.status,
             )
 
-    async def generate_stream(
+    async def generate_stream(
+        self, modified_request, prefill_server, decode_server, endpoint="generate"
+    ):
+        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
+
         async def stream_results():
             async with aiohttp.ClientSession(
                 timeout=aiohttp.ClientTimeout(
@@ -54,10 +59,10 @@ class MiniLoadBalancer:
                 # Create the tasks for both prefill and decode requests
                 tasks = [
                     session.post(
-                        f"{prefill_server}/
+                        f"{prefill_server}/{endpoint}", json=modified_request
                     ),
                     session.post(
-                        f"{decode_server}/
+                        f"{decode_server}/{endpoint}", json=modified_request
                     ),
                 ]
                 # Wait for both responses to complete. Since this is streaming, they return immediately.
@@ -157,6 +162,43 @@ async def get_model_info():
 async def handle_generate_request(request_data: dict):
     prefill_server, decode_server = load_balancer.select_pair()
 
+    # Parse and transform prefill_server for bootstrap data
+    parsed_url = urllib.parse.urlparse(prefill_server)
+    hostname = parsed_url.hostname
+    modified_request = request_data.copy()
+
+    batch_size = _get_request_batch_size(modified_request)
+    if batch_size is not None:
+        modified_request.update(
+            {
+                "bootstrap_host": [hostname] * batch_size,
+                "bootstrap_room": [
+                    _generate_bootstrap_room() for _ in range(batch_size)
+                ],
+            }
+        )
+    else:
+        modified_request.update(
+            {
+                "bootstrap_host": hostname,
+                "bootstrap_room": _generate_bootstrap_room(),
+            }
+        )
+
+    if request_data.get("stream", False):
+        return await load_balancer.generate_stream(
+            modified_request, prefill_server, decode_server, "generate"
+        )
+    else:
+        return await load_balancer.generate(
+            modified_request, prefill_server, decode_server, "generate"
+        )
+
+
+@app.post("/v1/chat/completions")
+async def handle_completion_request(request_data: dict):
+    prefill_server, decode_server = load_balancer.select_pair()
+
     # Parse and transform prefill_server for bootstrap data
     parsed_url = urllib.parse.urlparse(prefill_server)
     hostname = parsed_url.hostname
@@ -170,14 +212,33 @@ async def handle_generate_request(request_data: dict):
 
     if request_data.get("stream", False):
         return await load_balancer.generate_stream(
-            modified_request,
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
         )
     else:
         return await load_balancer.generate(
-            modified_request,
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
        )
 
 
+def _generate_bootstrap_room():
+    return random.randint(0, 2**63 - 1)
+
+
+# We may utilize `GenerateReqInput`'s logic later
+def _get_request_batch_size(request):
+    if (text := request.get("text")) is not None:
+        return None if isinstance(text, str) else len(text)
+    if (input_ids := request.get("input_ids")) is not None:
+        return None if isinstance(input_ids[0], int) else len(input_ids)
+    return None
+
+
 @app.get("/v1/models")
 async def get_models():
     prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server
```
sglang/srt/disaggregation/mooncake/conn.py
CHANGED
```diff
@@ -231,7 +231,7 @@ class MooncakeKVManager(BaseKVManager):
             chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
             assert len(chunked_dst_kv_indice) == len(
                 kv_chunk.prefill_kv_indices
-            )
+            ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
 
             ret = self.send_kvcache(
                 req.mooncake_session_id,
```
sglang/srt/disaggregation/nixl/__init__.py
ADDED
```diff
@@ -0,0 +1 @@
+from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender
```