sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mini_lb.py (new file)
@@ -0,0 +1,285 @@
+"""
+Minimal HTTP load balancer for prefill and decode servers for testing purpose.
+"""
+
+import asyncio
+import random
+import urllib
+from itertools import chain
+
+import aiohttp
+import orjson
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import ORJSONResponse, Response, StreamingResponse
+
+
+class MiniLoadBalancer:
+    def __init__(self, prefill_servers, decode_servers):
+        self.prefill_servers = prefill_servers
+        self.decode_servers = decode_servers
+
+    def select_pair(self):
+        return random.choice(self.prefill_servers), random.choice(self.decode_servers)
+
+    async def generate_request(self, request_data):
+        prefill_server, decode_server = self.select_pair()
+
+        # Parse and transform prefill_server
+        parsed_url = urllib.parse.urlparse(prefill_server)
+        hostname = parsed_url.hostname
+        bootstrap_host = f"{hostname}"
+
+        modified_request = request_data.copy()
+        modified_request.update(
+            {
+                "bootstrap_host": bootstrap_host,
+                "bootstrap_room": random.randint(0, 2**63 - 1),
+            }
+        )
+
+        async with aiohttp.ClientSession() as session:
+            # Create the tasks
+            tasks = [
+                session.post(f"{prefill_server}/generate", json=modified_request),
+                session.post(f"{decode_server}/generate", json=modified_request),
+            ]
+
+            prefill_response = None
+            decode_response = None
+
+            # Process responses as they arrive
+            for i, response in enumerate(asyncio.as_completed(tasks)):
+                response = await response
+                # Check if this is the prefill or decode response based on order created
+                if i == 0:  # First completed task
+                    if str(response.url).startswith(prefill_server):
+                        prefill_response = response
+                        if response.status != 200:
+                            raise HTTPException(
+                                status_code=response.status,
+                                detail=f"Prefill server error: Status {response.status} Details: {await response.text()}",
+                            )
+                    else:
+                        decode_response = response
+                        if response.status != 200:
+                            raise HTTPException(
+                                status_code=response.status,
+                                detail=f"Decode server error: Status {response.status} Details: {await response.text()}",
+                            )
+                else:  # Second completed task
+                    if str(response.url).startswith(prefill_server):
+                        prefill_response = response
+                    else:
+                        decode_response = response
+
+                    if response.status != 200:
+                        raise HTTPException(
+                            status_code=response.status,
+                            detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}",
+                        )
+
+            return await decode_response.json()
+
+
+app = FastAPI()
+load_balancer = None
+
+
+@app.get("/health")
+async def health_check():
+    return Response(status_code=200)
+
+
+@app.get("/health_generate")
+async def health_check():
+    prefill_servers, decode_servers = (
+        load_balancer.prefill_servers,
+        load_balancer.decode_servers,
+    )
+    async with aiohttp.ClientSession() as session:
+        # Create the tasks
+        tasks = []
+        for server in chain(prefill_servers, decode_servers):
+            tasks.append(session.post(f"{server}/health_generate"))
+        for i, response in enumerate(asyncio.as_completed(tasks)):
+            await response
+    return Response(status_code=200)
+
+
+@app.post("/flush_cache")
+async def flush_cache():
+    prefill_servers, decode_servers = (
+        load_balancer.prefill_servers,
+        load_balancer.decode_servers,
+    )
+    async with aiohttp.ClientSession() as session:
+        # Create the tasks
+        tasks = []
+        for server in chain(prefill_servers, decode_servers):
+            tasks.append(session.post(f"{server}/flush_cache"))
+        for i, response in enumerate(asyncio.as_completed(tasks)):
+            await response
+    return Response(status_code=200)
+
+
+@app.get("/get_server_info")
+async def get_server_info():
+    prefill_servers, decode_servers = (
+        load_balancer.prefill_servers,
+        load_balancer.decode_servers,
+    )
+    prefill_infos = []
+    decode_infos = []
+    async with aiohttp.ClientSession() as session:
+        for server in chain(prefill_servers):
+            server_info = await session.get(f"{server}/get_server_info")
+            prefill_infos.append(await server_info.json())
+        for server in chain(decode_servers):
+            server_info = await session.get(f"{server}/get_server_info")
+            decode_infos.append(await server_info.json())
+
+    return {"prefill": prefill_infos, "decode": decode_infos}
+
+
+@app.get("/get_model_info")
+async def get_model_info():
+    # Dummy model information
+    model_info = {
+        "model_path": "/path/to/dummy/model",
+        "tokenizer_path": "/path/to/dummy/tokenizer",
+        "is_generation": True,
+        "preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
+    }
+    return ORJSONResponse(content=model_info)
+
+
+@app.post("/generate")
+async def handle_generate_request(request_data: dict):
+    prefill_server, decode_server = load_balancer.select_pair()
+
+    # Parse and transform prefill_server for bootstrap data
+    parsed_url = urllib.parse.urlparse(prefill_server)
+    hostname = parsed_url.hostname
+    modified_request = request_data.copy()
+    modified_request.update(
+        {
+            "bootstrap_host": hostname,
+            "bootstrap_room": random.randint(0, 2**63 - 1),
+        }
+    )
+
+    # Check if streaming is requested
+    if request_data.get("stream", False):
+
+        async def stream_results():
+            async with aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(total=3600)
+            ) as session:
+                try:
+                    # Create the tasks
+                    tasks = [
+                        session.post(
+                            f"{prefill_server}/generate", json=modified_request
+                        ),
+                        session.post(
+                            f"{decode_server}/generate", json=modified_request
+                        ),
+                    ]
+
+                    prefill_response = None
+                    decode_response = None
+
+                    # Process responses as they arrive
+                    for i, response_task in enumerate(asyncio.as_completed(tasks)):
+                        response = await response_task
+
+                        # Check the response immediately
+                        if str(response.url).startswith(prefill_server):
+                            prefill_response = response
+                            if response.status != 200:
+                                error_msg = {
+                                    "error": {
+                                        "message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
+                                    }
+                                }
+                                yield b"data: " + orjson.dumps(
+                                    error_msg, option=orjson.OPT_NON_STR_KEYS
+                                ) + b"\n\n"
+                                return
+                        else:
+                            decode_response = response
+                            if response.status != 200:
+                                error_msg = {
+                                    "error": {
+                                        "message": f"Decode server error: Status {response.status}"
+                                    }
+                                }
+                                yield b"data: " + orjson.dumps(
+                                    error_msg, option=orjson.OPT_NON_STR_KEYS
+                                ) + b"\n\n"
+                                return
+
+                    # Stream successful decode server response
+                    async for line in decode_response.content:
+                        yield line
+                    yield b"data: [DONE]\n\n"
+
+                except Exception as e:
+                    error_msg = {
+                        "error": {"message": f"Stream processing error: {str(e)}"}
+                    }
+                    yield b"data: " + orjson.dumps(
+                        error_msg, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+        )
+
+    # Non-streaming case
+    result = await load_balancer.generate_request(request_data)
+    return ORJSONResponse(content=result)
+
+
+@app.get("/v1/models")
+async def get_models():
+    prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server
+    async with aiohttp.ClientSession() as session:
+        try:
+            response = await session.get(f"{prefill_server}/v1/models")
+            if response.status != 200:
+                raise HTTPException(
+                    status_code=response.status,
+                    detail=f"Prefill server error: Status {response.status}",
+                )
+            return ORJSONResponse(content=await response.json())
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+
+def run(prefill_addrs, decode_addrs, host, port):
+    global load_balancer
+    load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
+    uvicorn.run(app, host=host, port=port)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
+    parser.add_argument(
+        "--prefill", required=True, help="Comma-separated URLs for prefill servers"
+    )
+    parser.add_argument(
+        "--decode", required=True, help="Comma-separated URLs for decode servers"
+    )
+    parser.add_argument(
+        "--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
+    )
+    args = parser.parse_args()
+    run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
sglang/srt/disaggregation/prefill.py (new file)
@@ -0,0 +1,249 @@
+"""
+Life cycle of a request in the prefill server
+
+1. Bootstrap Queue
+    a. Initialize a sender for each request
+    b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished
+    c. Poll senders to check bootstrap state
+    d. Once bootstrap is complete, move request to Waiting Queue
+
+2. Waiting Queue
+    a. Use PrefillAdder to pop requests
+    b. Run forward
+    c. Add the request to Infight Queue
+
+3. Infight Queue
+    a. Poll (non-blocking) the sender of the request
+    b. Once the transfer has finished, return the request
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
+from sglang.srt.disaggregation.utils import (
+    ReqToMetadataIdxAllocator,
+    poll_and_all_reduce,
+)
+from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
+
+if TYPE_CHECKING:
+    from torch.distributed import ProcessGroup
+
+    from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+logger = logging.getLogger(__name__)
+
+
+class PrefillBootstrapQueue:
+    """
+    Store the requests in bootstrapping
+    """
+
+    def __init__(
+        self,
+        token_to_kv_pool: KVCache,
+        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+        metadata_buffers: List[torch.Tensor],
+        aux_dtype: torch.dtype,
+        tp_rank: int,
+        tp_size: int,
+        bootstrap_port: int,
+        gloo_group: ProcessGroup,
+    ):
+        self.token_to_kv_pool = token_to_kv_pool
+        self.aux_dtype = aux_dtype
+
+        self.metadata_buffers = metadata_buffers
+        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+        self.tp_rank = tp_rank
+        self.tp_size = tp_size
+        self.kv_manager = self._init_kv_manager()
+        self.queue: List[Req] = []
+        self.gloo_group = gloo_group
+        self.bootstrap_port = bootstrap_port
+
+    def allocate_token_id(self, idx: int, token_id: int):
+        assert token_id >= 0, f"token_id: {token_id} is negative"
+        output_id_buffer = self.metadata_buffers[0]
+        output_id_buffer[idx] = token_id
+
+    def _init_kv_manager(self) -> KVManager:
+        kv_args = KVArgs()
+        kv_args.engine_rank = self.tp_rank
+        kv_data_ptrs, kv_data_lens, kv_item_lens = (
+            self.token_to_kv_pool.get_contiguous_buf_infos()
+        )
+
+        kv_args.kv_data_ptrs = kv_data_ptrs
+        kv_args.kv_data_lens = kv_data_lens
+        kv_args.kv_item_lens = kv_item_lens
+
+        # Define req -> input ids buffer
+        kv_args.aux_data_ptrs = [
+            metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
+        ]
+        kv_args.aux_data_lens = [
+            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
+        ]
+        kv_args.aux_item_lens = [
+            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
+        ]
+        kv_args.ib_device = "mock-ib-device"
+        kv_manager = KVManager(kv_args)
+        return kv_manager
+
+    def add(self, req: Req) -> None:
+        req.disagg_kv_sender = KVSender(
+            mgr=self.kv_manager,
+            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
+            bootstrap_room=req.bootstrap_room,
+        )
+        self._process_req(req)
+        self.queue.append(req)
+
+    def _process_req(self, req: Req) -> None:
+        """
+        Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
+        """
+        req.sampling_params.max_new_tokens = 1
+
+    def pop_bootstrapped(self) -> List[Req]:
+        """pop the reqs which has finished bootstrapping"""
+        bootstrapped_reqs = []
+        indices_to_remove = set()
+
+        if len(self.queue) == 0:
+            return []
+
+        polls = poll_and_all_reduce(
+            [req.disagg_kv_sender for req in self.queue], self.gloo_group
+        )
+
+        for i, (req, poll) in enumerate(zip(self.queue, polls)):
+            if poll == KVPoll.Bootstrapping:
+                continue
+            elif poll == KVPoll.Failed:
+                raise Exception("Bootstrap failed")
+
+            # KV.WaitingForInput - init here
+            num_kv_indices = len(req.origin_input_ids)
+            if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
+                break
+
+            req.metadata_buffer_index = (
+                self.req_to_metadata_buffer_idx_allocator.alloc()
+            )
+            assert req.metadata_buffer_index is not None
+            req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
+
+            bootstrapped_reqs.append(req)
+            indices_to_remove.add(i)
+
+        self.queue = [
+            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
+        ]
+
+        return bootstrapped_reqs
+
+
+class SchedulerDisaggregationPrefillMixin:
+    """
+    Mixin for Scheduler to handle disaggregation prefill
+    """
+
+    def process_batch_result_disagg_prefill(
+        self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
+    ) -> None:
+        """
+        Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
+        Adapted from process_batch_result_prefill
+        """
+
+        next_token_ids = result.next_token_ids.tolist()
+
+        for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
+            req: Req
+            if req.is_chunked <= 0:
+                # There is no output_ids for prefill
+                req.output_ids.append(next_token_id)
+                self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
+                self.send_kv_chunk(req, token_id=next_token_id)
+                self.disagg_prefill_infight_queue.append(req)
+            else:
+                # being chunked reqs' prefill is not finished
+                req.is_chunked -= 1
+
+        # TODO: Not sure if this is necessary
+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            # We need to remove this for overlap schedule.
+            self.current_stream.synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
+    def process_disagg_prefill_infight_queue(self: Scheduler) -> None:
+        """
+        Poll the requests in the middle of transfer. If done, return the request.
+        """
+        assert len(self.disagg_prefill_infight_queue) > 0
+
+        done_reqs = []
+
+        polls = poll_and_all_reduce(
+            [req.disagg_kv_sender for req in self.disagg_prefill_infight_queue],
+            self.tp_worker.get_tp_cpu_group(),
+        )
+
+        undone_reqs: List[Req] = []
+        # Check .poll() for the reqs in disagg_prefill_infight_queue. If Success, respond to the client and remove it from the queue
+        for req, poll in zip(self.disagg_prefill_infight_queue, polls):
+            if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
+                undone_reqs.append(req)
+            elif poll == KVPoll.Success:  # transfer done
+                self.tree_cache.cache_finished_req(req)  # unlock the tree
+                req.finished_reason = FINISH_LENGTH(length=0)
+                done_reqs.append(req)
+            elif poll == KVPoll.Failed:
+                raise Exception("Transferring failed")
+
+        # Stream requests which have finished transfer
+        self.stream_output(done_reqs, False, None)
+
+        self.disagg_prefill_infight_queue = undone_reqs
+
+    def process_prefill_chunk(self: Scheduler) -> None:
+        if self.last_batch and self.last_batch.forward_mode.is_extend():
+            if self.chunked_req:
+                # Move the chunked request out of the batch so that we can merge
+                # only finished requests to running_batch.
+                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
+                self.tree_cache.cache_unfinished_req(self.chunked_req)
+                self.send_kv_chunk(self.chunked_req)
+                # chunked request keeps its rid but will get a new req_pool_idx
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+                self.running_batch.batch_is_full = False
+
+    def send_kv_chunk(
+        self: Scheduler, req: Req, token_id: Optional[int] = None
+    ) -> None:
+        """
+        Send a prefilled chunk to the decode server
+        """
+        start_idx = req.start_send_idx
+        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
+        kv_indices = (
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
+            .cpu()
+            .numpy()
+        )
+        req.start_send_idx = end_idx
+        if token_id is not None:
+            self.disagg_prefill_pending_queue.allocate_token_id(
+                req.metadata_buffer_index, token_id
+            )
+        req.disagg_kv_sender.send(kv_indices)
sglang/srt/disaggregation/utils.py (new file)
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from collections import deque
+from enum import Enum
+from typing import List
+
+import torch
+import torch.distributed as dist
+
+
+class DisaggregationMode(Enum):
+    NULL = "null"
+    PREFILL = "prefill"
+    DECODE = "decode"
+
+
+def poll_and_all_reduce(pollers, gloo_group):
+    polls = [int(poller.poll()) for poller in pollers]
+    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
+    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
+    return tensor_to_reduce.tolist()
+
+
+class ReqToMetadataIdxAllocator:
+    """A memory pool that maps a request to its first output token location."""
+
+    def __init__(
+        self,
+        size: int,
+    ):
+        self.size = size
+        self.free_slots = deque(list(range(size)))
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def alloc(self) -> List[int]:
+        if len(self.free_slots) == 0:
+            return None
+
+        return self.free_slots.popleft()
+
+    def free(self, free_index: int):
+        self.free_slots.append(free_index)
sglang/srt/distributed/parallel_state.py
@@ -189,6 +189,9 @@ class GroupCoordinator:
     device_group: ProcessGroup  # group for device communication
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    use_message_queue_broadcaster: (
+        bool  # a hint of whether to use message queue broadcaster
+    )
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
@@ -241,6 +244,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_message_queue_broadcaster = use_message_queue_broadcaster
 
         # lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
@@ -269,7 +273,7 @@ class GroupCoordinator:
             HpuCommunicator,
         )
 
-        self.hpu_communicator: Optional[HpuCommunicator]
+        self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
            self.hpu_communicator = HpuCommunicator(group=self.device_group)
 
@@ -277,7 +281,7 @@ class GroupCoordinator:
             XpuCommunicator,
         )
 
-        self.xpu_communicator: Optional[XpuCommunicator]
+        self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
 
@@ -1312,7 +1316,10 @@ vllm_get_world_group = None
 
 
 def monkey_patch_vllm_parallel_state(reverse: bool = False):
-
+    try:
+        import vllm.distributed.parallel_state as vllm_parrlel_state
+    except ImportError:
+        return
 
     global vllm_get_pp_group, vllm_get_tp_group, vllm_get_world_group
     if vllm_get_pp_group is None: