sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +21 -0
- sglang/bench_serving.py +10 -4
- sglang/lang/chat_template.py +24 -0
- sglang/srt/configs/model_config.py +40 -4
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/xgrammar_backend.py +1 -0
- sglang/srt/conversation.py +29 -4
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +18 -5
- sglang/srt/disaggregation/mini_lb.py +53 -122
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +615 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
- sglang/srt/disaggregation/prefill.py +43 -19
- sglang/srt/disaggregation/utils.py +31 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +37 -10
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/attention/flashattention_backend.py +609 -202
- sglang/srt/layers/attention/flashinfer_backend.py +13 -7
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +37 -16
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
- sglang/srt/layers/quantization/fp8.py +28 -14
- sglang/srt/layers/quantization/fp8_kernel.py +130 -4
- sglang/srt/layers/quantization/fp8_utils.py +34 -6
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +14 -0
- sglang/srt/layers/rotary_embedding.py +75 -1
- sglang/srt/managers/io_struct.py +254 -97
- sglang/srt/managers/mm_utils.py +3 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
- sglang/srt/managers/schedule_batch.py +62 -21
- sglang/srt/managers/scheduler.py +71 -14
- sglang/srt/managers/tokenizer_manager.py +17 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +14 -1
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +7 -4
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +49 -9
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +1 -0
- sglang/srt/models/deepseek_v2.py +248 -61
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +13 -4
- sglang/srt/models/llama4.py +487 -0
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +2 -0
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +227 -0
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +1 -0
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +1 -0
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/server_args.py +34 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +6 -2
- sglang/srt/utils.py +120 -9
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/test_block_fp8.py +57 -0
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
- sglang/srt/disaggregation/conn.py +0 -81
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py

```diff
@@ -24,12 +24,17 @@ import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
+import numpy as np
 import torch
 from torch.distributed import ProcessGroup
 
-from sglang.srt.disaggregation.
+from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    KVClassType,
     ReqToMetadataIdxAllocator,
+    TransferBackend,
+    get_kv_class,
     poll_and_all_reduce,
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
```
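These imports replace the deleted `sglang/srt/disaggregation/conn.py` with the new pluggable `base` package (added above as `base/conn.py`, +113). Only the class names and the constructor arguments visible later in this diff come from the source; the field lists, method set, and state values in the sketch below are illustrative assumptions about what such an interface could look like.

```python
# Hedged sketch of the interface shape implied by the imports above; every
# detail not visible in the hunks below is an assumption, not the shipped code.
from abc import ABC, abstractmethod
from typing import List, Optional


class KVArgs:
    """Plain container the scheduler fills in (fields assumed from the hunks below)."""

    engine_rank: int
    kv_data_ptrs: List[int]
    kv_data_lens: List[int]
    kv_item_lens: List[int]
    ib_device: str
    gpu_id: int


class KVPoll:
    """Transfer states polled via poll_and_all_reduce (names and values assumed)."""

    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


class BaseKVManager(ABC):
    """One per scheduler process; owns transfer-engine state shared by all requests."""

    @abstractmethod
    def __init__(self, args: KVArgs, disaggregation_mode, server_args): ...


class BaseKVReceiver(ABC):
    """One per decode request; pulls KV cache from the prefill side."""

    @abstractmethod
    def __init__(
        self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: Optional[int]
    ): ...

    @abstractmethod
    def poll(self) -> int:
        """Return a KVPoll state; the decode queue all-reduces these across TP ranks."""
```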
```diff
@@ -49,7 +54,7 @@ if TYPE_CHECKING:
 @dataclass
 class DecodeRequest:
     req: Req
-    kv_receiver:
+    kv_receiver: BaseKVReceiver
     waiting_for_input: bool = False
     metadata_buffer_index: int = -1
 
@@ -73,6 +78,7 @@ class DecodePreallocQueue:
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
+        transfer_backend: TransferBackend,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -92,9 +98,10 @@ class DecodePreallocQueue:
 
         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
+        self.transfer_backend = transfer_backend
         self.kv_manager = self._init_kv_manager()
 
-    def _init_kv_manager(self) ->
+    def _init_kv_manager(self) -> BaseKVManager:
         kv_args = KVArgs()
         kv_args.engine_rank = self.tp_rank
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
```
```diff
@@ -115,13 +122,18 @@ class DecodePreallocQueue:
             metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
         ]
         kv_args.ib_device = "mock-ib-device"
-
+        kv_args.gpu_id = self.scheduler.gpu_id
+        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
+        kv_manager = kv_manager_class(
+            kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
+        )
         return kv_manager
 
     def add(self, req: Req) -> None:
         """Add a request to the pending queue."""
 
-
+        kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
+        kv_receiver = kv_receiver_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
```
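Both call sites now resolve concrete classes through `get_kv_class`, so `DecodePreallocQueue` never imports a transfer backend directly. A minimal sketch of what such a factory could look like: `TransferBackend`, `KVClassType`, and the `MANAGER`/`RECEIVER` roles appear in this diff, while the `MOONCAKE`/`FAKE` members, the `SENDER` role, and the `Mooncake*` class names are assumptions inferred from the new `mooncake/` package listed above.

```python
# Illustrative backend factory; names beyond those visible in the diff are assumed.
from enum import Enum


class TransferBackend(Enum):
    MOONCAKE = "mooncake"
    FAKE = "fake"  # assumed mock backend for tests


class KVClassType(Enum):
    MANAGER = "manager"
    SENDER = "sender"  # assumed prefill-side counterpart
    RECEIVER = "receiver"


def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
    """Resolve a (backend, role) pair to a concrete connector class."""
    if transfer_backend == TransferBackend.MOONCAKE:
        from sglang.srt.disaggregation.mooncake.conn import (  # assumed exports
            MooncakeKVManager,
            MooncakeKVReceiver,
            MooncakeKVSender,
        )

        return {
            KVClassType.MANAGER: MooncakeKVManager,
            KVClassType.SENDER: MooncakeKVSender,
            KVClassType.RECEIVER: MooncakeKVReceiver,
        }[class_type]
    raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
```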
```diff
@@ -186,6 +198,7 @@
             ]
             .cpu()
             .numpy()
+            .astype(np.int64)
         )
 
         decode_req.metadata_buffer_index = (
```
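The explicit cast (together with the new `import numpy as np` above) pins the dtype of the KV index array: `Tensor.cpu().numpy()` preserves whatever dtype the tensor had, so a 32-bit index tensor would otherwise hand the transfer backend 32-bit offsets. A standalone illustration:

```python
import numpy as np
import torch

# .numpy() carries the tensor dtype over unchanged...
idx = torch.arange(4, dtype=torch.int32)
print(idx.cpu().numpy().dtype)  # int32

# ...so the explicit cast is what guarantees a fixed 8-byte index layout,
# presumably what the KV transfer side expects regardless of tensor dtype.
print(idx.cpu().numpy().astype(np.int64).dtype)  # int64
```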
sglang/srt/disaggregation/mini_lb.py

```diff
@@ -1,5 +1,5 @@
 """
-Minimal HTTP load balancer for prefill and decode servers for testing
+Minimal HTTP load balancer for prefill and decode servers for testing.
 """
 
 import asyncio
```
```diff
@@ -22,64 +22,59 @@ class MiniLoadBalancer:
     def select_pair(self):
         return random.choice(self.prefill_servers), random.choice(self.decode_servers)
 
-    async def
-        prefill_server, decode_server
-
-        # Parse and transform prefill_server
-        parsed_url = urllib.parse.urlparse(prefill_server)
-        hostname = parsed_url.hostname
-        bootstrap_host = f"{hostname}"
-
-        modified_request = request_data.copy()
-        modified_request.update(
-            {
-                "bootstrap_host": bootstrap_host,
-                "bootstrap_room": random.randint(0, 2**63 - 1),
-            }
-        )
+    async def generate(
+        self, modified_request, prefill_server, decode_server
+    ) -> ORJSONResponse:
 
         async with aiohttp.ClientSession() as session:
-            # Create the tasks
             tasks = [
                 session.post(f"{prefill_server}/generate", json=modified_request),
                 session.post(f"{decode_server}/generate", json=modified_request),
             ]
+            # Wait for both responses to complete. Prefill should end first.
+            prefill_response, decode_response = await asyncio.gather(*tasks)
+
+            return ORJSONResponse(
+                content=await decode_response.json(),
+                status_code=decode_response.status,
+            )
+
+    async def generate_stream(self, modified_request, prefill_server, decode_server):
+        async def stream_results():
+            async with aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(
+                    total=3600
+                )  # Add timeout for request reliability
+            ) as session:
+                try:
+                    # Create the tasks for both prefill and decode requests
+                    tasks = [
+                        session.post(
+                            f"{prefill_server}/generate", json=modified_request
+                        ),
+                        session.post(
+                            f"{decode_server}/generate", json=modified_request
+                        ),
+                    ]
+                    # Wait for both responses to complete. Since this is streaming, they return immediately.
+                    prefill_response, decode_response = await asyncio.gather(*tasks)
+                    async for chunk in decode_response.content:
+                        yield chunk
+                except Exception as e:
+                    error_msg = {
+                        "error": {"message": f"Stream processing error: {str(e)}"}
+                    }
+                    yield b"data: " + orjson.dumps(
+                        error_msg, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+                finally:
+                    if prefill_response is not None:
+                        await prefill_response.release()
 
-
-
-
-
-            for i, response in enumerate(asyncio.as_completed(tasks)):
-                response = await response
-                # Check if this is the prefill or decode response based on order created
-                if i == 0:  # First completed task
-                    if str(response.url).startswith(prefill_server):
-                        prefill_response = response
-                        if response.status != 200:
-                            raise HTTPException(
-                                status_code=response.status,
-                                detail=f"Prefill server error: Status {response.status} Details: {await response.text()}",
-                            )
-                    else:
-                        decode_response = response
-                        if response.status != 200:
-                            raise HTTPException(
-                                status_code=response.status,
-                                detail=f"Decode server error: Status {response.status} Details: {await response.text()}",
-                            )
-                else:  # Second completed task
-                    if str(response.url).startswith(prefill_server):
-                        prefill_response = response
-                    else:
-                        decode_response = response
-
-                    if response.status != 200:
-                        raise HTTPException(
-                            status_code=response.status,
-                            detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}",
-                        )
-
-            return await decode_response.json()
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+        )
 
 
 app = FastAPI()
```
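The rewrite also swaps `asyncio.as_completed`, which yields results in completion order and forced the old code to inspect `response.url` to tell the servers apart, for `asyncio.gather`, which returns results in the order the awaitables were passed. A self-contained demonstration of that ordering guarantee:

```python
import asyncio


async def respond(name: str, delay: float) -> str:
    await asyncio.sleep(delay)
    return name


async def main() -> None:
    # gather() returns in argument order even though "decode" finishes first,
    # so unpacking into (prefill, decode) is unambiguous.
    prefill, decode = await asyncio.gather(
        respond("prefill", 0.2), respond("decode", 0.1)
    )
    print(prefill, decode)  # prefill decode

    # as_completed() yields in completion order, so the caller must identify
    # each result itself, which is the ambiguity the old code handled by
    # checking response.url.
    t1 = asyncio.create_task(respond("prefill", 0.2))
    t2 = asyncio.create_task(respond("decode", 0.1))
    for fut in asyncio.as_completed([t1, t2]):
        print(await fut)  # decode, then prefill


asyncio.run(main())
```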
```diff
@@ -169,78 +164,14 @@ async def handle_generate_request(request_data: dict):
         }
     )
 
-    # Check if streaming is requested
     if request_data.get("stream", False):
-
-
-
-
-
-
-                    # Create the tasks
-                    tasks = [
-                        session.post(
-                            f"{prefill_server}/generate", json=modified_request
-                        ),
-                        session.post(
-                            f"{decode_server}/generate", json=modified_request
-                        ),
-                    ]
-
-                    prefill_response = None
-                    decode_response = None
-
-                    # Process responses as they arrive
-                    for i, response_task in enumerate(asyncio.as_completed(tasks)):
-                        response = await response_task
-
-                        # Check the response immediately
-                        if str(response.url).startswith(prefill_server):
-                            prefill_response = response
-                            if response.status != 200:
-                                error_msg = {
-                                    "error": {
-                                        "message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
-                                    }
-                                }
-                                yield b"data: " + orjson.dumps(
-                                    error_msg, option=orjson.OPT_NON_STR_KEYS
-                                ) + b"\n\n"
-                                return
-                        else:
-                            decode_response = response
-                            if response.status != 200:
-                                error_msg = {
-                                    "error": {
-                                        "message": f"Decode server error: Status {response.status}"
-                                    }
-                                }
-                                yield b"data: " + orjson.dumps(
-                                    error_msg, option=orjson.OPT_NON_STR_KEYS
-                                ) + b"\n\n"
-                                return
-
-                    # Stream successful decode server response
-                    async for line in decode_response.content:
-                        yield line
-                    yield b"data: [DONE]\n\n"
-
-                except Exception as e:
-                    error_msg = {
-                        "error": {"message": f"Stream processing error: {str(e)}"}
-                    }
-                    yield b"data: " + orjson.dumps(
-                        error_msg, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
+        return await load_balancer.generate_stream(
+            modified_request, prefill_server, decode_server
+        )
+    else:
+        return await load_balancer.generate(
+            modified_request, prefill_server, decode_server
         )
-
-    # Non-streaming case
-    result = await load_balancer.generate_request(request_data)
-    return ORJSONResponse(content=result)
 
 
 @app.get("/v1/models")
```
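For reference, a hypothetical client exercising the balancer's `/generate` route; the host/port and every payload field except `stream` are assumptions here, not taken from this diff.

```python
# Hypothetical client sketch; assumes the mini load balancer listens on
# localhost:8000 and that the upstream servers accept a "text" prompt field.
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"text": "The capital of France is", "stream": True},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())  # SSE chunks proxied from the decode server
```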