sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mini_lb.py
@@ -0,0 +1,285 @@
+"""
+Minimal HTTP load balancer for prefill and decode servers, for testing purposes.
+"""
+
+import asyncio
+import random
+import urllib.parse
+from itertools import chain
+
+import aiohttp
+import orjson
+import uvicorn
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import ORJSONResponse, Response, StreamingResponse
+
+
+class MiniLoadBalancer:
+    def __init__(self, prefill_servers, decode_servers):
+        self.prefill_servers = prefill_servers
+        self.decode_servers = decode_servers
+
+    def select_pair(self):
+        return random.choice(self.prefill_servers), random.choice(self.decode_servers)
+
+    async def generate_request(self, request_data):
+        prefill_server, decode_server = self.select_pair()
+
+        # Parse and transform prefill_server
+        parsed_url = urllib.parse.urlparse(prefill_server)
+        hostname = parsed_url.hostname
+        bootstrap_host = f"{hostname}"
+
+        modified_request = request_data.copy()
+        modified_request.update(
+            {
+                "bootstrap_host": bootstrap_host,
+                "bootstrap_room": random.randint(0, 2**63 - 1),
+            }
+        )
+
+        async with aiohttp.ClientSession() as session:
+            # Create the tasks
+            tasks = [
+                session.post(f"{prefill_server}/generate", json=modified_request),
+                session.post(f"{decode_server}/generate", json=modified_request),
+            ]
+
+            prefill_response = None
+            decode_response = None
+
+            # Process responses as they arrive
+            for i, response in enumerate(asyncio.as_completed(tasks)):
+                response = await response
+                # Check if this is the prefill or decode response based on its URL
+                if i == 0:  # First completed task
+                    if str(response.url).startswith(prefill_server):
+                        prefill_response = response
+                        if response.status != 200:
+                            raise HTTPException(
+                                status_code=response.status,
+                                detail=f"Prefill server error: Status {response.status} Details: {await response.text()}",
+                            )
+                    else:
+                        decode_response = response
+                        if response.status != 200:
+                            raise HTTPException(
+                                status_code=response.status,
+                                detail=f"Decode server error: Status {response.status} Details: {await response.text()}",
+                            )
+                else:  # Second completed task
+                    if str(response.url).startswith(prefill_server):
+                        prefill_response = response
+                    else:
+                        decode_response = response
+
+                    if response.status != 200:
+                        raise HTTPException(
+                            status_code=response.status,
+                            detail=f"{'Prefill' if str(response.url).startswith(prefill_server) else 'Decode'} server error: Status {response.status} Details: {await response.text()}",
+                        )
+
+            return await decode_response.json()
+
+
+app = FastAPI()
+load_balancer = None
+
+
+@app.get("/health")
+async def health_check():
+    return Response(status_code=200)
+
+
+@app.get("/health_generate")
+async def health_check_generate():
+    prefill_servers, decode_servers = (
+        load_balancer.prefill_servers,
+        load_balancer.decode_servers,
+    )
+    async with aiohttp.ClientSession() as session:
+        # Create the tasks
+        tasks = []
+        for server in chain(prefill_servers, decode_servers):
+            tasks.append(session.post(f"{server}/health_generate"))
+        for i, response in enumerate(asyncio.as_completed(tasks)):
+            await response
+    return Response(status_code=200)
+
+
+@app.post("/flush_cache")
+async def flush_cache():
+    prefill_servers, decode_servers = (
+        load_balancer.prefill_servers,
+        load_balancer.decode_servers,
+    )
+    async with aiohttp.ClientSession() as session:
+        # Create the tasks
+        tasks = []
+        for server in chain(prefill_servers, decode_servers):
+            tasks.append(session.post(f"{server}/flush_cache"))
+        for i, response in enumerate(asyncio.as_completed(tasks)):
+            await response
+    return Response(status_code=200)
+
+
+@app.get("/get_server_info")
+async def get_server_info():
+    prefill_servers, decode_servers = (
+        load_balancer.prefill_servers,
+        load_balancer.decode_servers,
+    )
+    prefill_infos = []
+    decode_infos = []
+    async with aiohttp.ClientSession() as session:
+        for server in chain(prefill_servers):
+            server_info = await session.get(f"{server}/get_server_info")
+            prefill_infos.append(await server_info.json())
+        for server in chain(decode_servers):
+            server_info = await session.get(f"{server}/get_server_info")
+            decode_infos.append(await server_info.json())
+
+    return {"prefill": prefill_infos, "decode": decode_infos}
+
+
+@app.get("/get_model_info")
+async def get_model_info():
+    # Dummy model information
+    model_info = {
+        "model_path": "/path/to/dummy/model",
+        "tokenizer_path": "/path/to/dummy/tokenizer",
+        "is_generation": True,
+        "preferred_sampling_params": {"temperature": 0.7, "max_new_tokens": 128},
+    }
+    return ORJSONResponse(content=model_info)
+
+
+@app.post("/generate")
+async def handle_generate_request(request_data: dict):
+    prefill_server, decode_server = load_balancer.select_pair()
+
+    # Parse and transform prefill_server for bootstrap data
+    parsed_url = urllib.parse.urlparse(prefill_server)
+    hostname = parsed_url.hostname
+    modified_request = request_data.copy()
+    modified_request.update(
+        {
+            "bootstrap_host": hostname,
+            "bootstrap_room": random.randint(0, 2**63 - 1),
+        }
+    )
+
+    # Check if streaming is requested
+    if request_data.get("stream", False):
+
+        async def stream_results():
+            async with aiohttp.ClientSession(
+                timeout=aiohttp.ClientTimeout(total=3600)
+            ) as session:
+                try:
+                    # Create the tasks
+                    tasks = [
+                        session.post(
+                            f"{prefill_server}/generate", json=modified_request
+                        ),
+                        session.post(
+                            f"{decode_server}/generate", json=modified_request
+                        ),
+                    ]
+
+                    prefill_response = None
+                    decode_response = None
+
+                    # Process responses as they arrive
+                    for i, response_task in enumerate(asyncio.as_completed(tasks)):
+                        response = await response_task
+
+                        # Check the response immediately
+                        if str(response.url).startswith(prefill_server):
+                            prefill_response = response
+                            if response.status != 200:
+                                error_msg = {
+                                    "error": {
+                                        "message": f"Prefill server error: Status {response.status}, Details: {await response.text()}"
+                                    }
+                                }
+                                yield b"data: " + orjson.dumps(
+                                    error_msg, option=orjson.OPT_NON_STR_KEYS
+                                ) + b"\n\n"
+                                return
+                        else:
+                            decode_response = response
+                            if response.status != 200:
+                                error_msg = {
+                                    "error": {
+                                        "message": f"Decode server error: Status {response.status}"
+                                    }
+                                }
+                                yield b"data: " + orjson.dumps(
+                                    error_msg, option=orjson.OPT_NON_STR_KEYS
+                                ) + b"\n\n"
+                                return
+
+                    # Stream successful decode server response
+                    async for line in decode_response.content:
+                        yield line
+                    yield b"data: [DONE]\n\n"
+
+                except Exception as e:
+                    error_msg = {
+                        "error": {"message": f"Stream processing error: {str(e)}"}
+                    }
+                    yield b"data: " + orjson.dumps(
+                        error_msg, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+        )
+
+    # Non-streaming case
+    result = await load_balancer.generate_request(request_data)
+    return ORJSONResponse(content=result)
+
+
+@app.get("/v1/models")
+async def get_models():
+    prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server
+    async with aiohttp.ClientSession() as session:
+        try:
+            response = await session.get(f"{prefill_server}/v1/models")
+            if response.status != 200:
+                raise HTTPException(
+                    status_code=response.status,
+                    detail=f"Prefill server error: Status {response.status}",
+                )
+            return ORJSONResponse(content=await response.json())
+        except Exception as e:
+            raise HTTPException(status_code=500, detail=str(e))
+
+
+def run(prefill_addrs, decode_addrs, host, port):
+    global load_balancer
+    load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
+    uvicorn.run(app, host=host, port=port)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Mini Load Balancer Server")
+    parser.add_argument(
+        "--prefill", required=True, help="Comma-separated URLs for prefill servers"
+    )
+    parser.add_argument(
+        "--decode", required=True, help="Comma-separated URLs for decode servers"
+    )
+    parser.add_argument(
+        "--host", default="0.0.0.0", help="Host to bind the server (default: 0.0.0.0)"
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
+    )
+    args = parser.parse_args()
+    run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
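
For context, a minimal way to exercise the new mini load balancer end to end might look like the sketch below. The addresses, ports, and request payload are illustrative assumptions, not values taken from this diff; the payload shape follows sglang's /generate API.

# Start the balancer in front of one prefill and one decode server
# (hypothetical addresses):
#   python -m sglang.srt.disaggregation.mini_lb \
#       --prefill http://127.0.0.1:30000 --decode http://127.0.0.1:30001 --port 8000

import requests

resp = requests.post(
    "http://127.0.0.1:8000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16},
    },
)
# The balancer forwards the request to both servers and returns the
# decode server's JSON result.
print(resp.json())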
sglang/srt/disaggregation/prefill.py
@@ -0,0 +1,249 @@
+"""
+Life cycle of a request in the prefill server
+
+1. Bootstrap Queue
+    a. Initialize a sender for each request
+    b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished
+    c. Poll senders to check bootstrap state
+    d. Once bootstrap is complete, move the request to the Waiting Queue
+
+2. Waiting Queue
+    a. Use PrefillAdder to pop requests
+    b. Run forward
+    c. Add the request to the Infight Queue
+
+3. Infight Queue
+    a. Poll (non-blocking) the sender of the request
+    b. Once the transfer has finished, return the request
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
+from sglang.srt.disaggregation.utils import (
+    ReqToMetadataIdxAllocator,
+    poll_and_all_reduce,
+)
+from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
+
+if TYPE_CHECKING:
+    from torch.distributed import ProcessGroup
+
+    from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+logger = logging.getLogger(__name__)
+
+
+class PrefillBootstrapQueue:
+    """
+    Store the requests that are bootstrapping
+    """
+
+    def __init__(
+        self,
+        token_to_kv_pool: KVCache,
+        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
+        metadata_buffers: List[torch.Tensor],
+        aux_dtype: torch.dtype,
+        tp_rank: int,
+        tp_size: int,
+        bootstrap_port: int,
+        gloo_group: ProcessGroup,
+    ):
+        self.token_to_kv_pool = token_to_kv_pool
+        self.aux_dtype = aux_dtype
+
+        self.metadata_buffers = metadata_buffers
+        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
+        self.tp_rank = tp_rank
+        self.tp_size = tp_size
+        self.kv_manager = self._init_kv_manager()
+        self.queue: List[Req] = []
+        self.gloo_group = gloo_group
+        self.bootstrap_port = bootstrap_port
+
+    def allocate_token_id(self, idx: int, token_id: int):
+        assert token_id >= 0, f"token_id: {token_id} is negative"
+        output_id_buffer = self.metadata_buffers[0]
+        output_id_buffer[idx] = token_id
+
+    def _init_kv_manager(self) -> KVManager:
+        kv_args = KVArgs()
+        kv_args.engine_rank = self.tp_rank
+        kv_data_ptrs, kv_data_lens, kv_item_lens = (
+            self.token_to_kv_pool.get_contiguous_buf_infos()
+        )
+
+        kv_args.kv_data_ptrs = kv_data_ptrs
+        kv_args.kv_data_lens = kv_data_lens
+        kv_args.kv_item_lens = kv_item_lens
+
+        # Define req -> input ids buffer
+        kv_args.aux_data_ptrs = [
+            metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
+        ]
+        kv_args.aux_data_lens = [
+            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
+        ]
+        kv_args.aux_item_lens = [
+            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
+        ]
+        kv_args.ib_device = "mock-ib-device"
+        kv_manager = KVManager(kv_args)
+        return kv_manager
+
+    def add(self, req: Req) -> None:
+        req.disagg_kv_sender = KVSender(
+            mgr=self.kv_manager,
+            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
+            bootstrap_room=req.bootstrap_room,
+        )
+        self._process_req(req)
+        self.queue.append(req)
+
+    def _process_req(self, req: Req) -> None:
+        """
+        Set max_new_tokens = 1 so that PrefillAdder's memory estimation is accurate
+        """
+        req.sampling_params.max_new_tokens = 1
+
+    def pop_bootstrapped(self) -> List[Req]:
+        """Pop the requests that have finished bootstrapping."""
+        bootstrapped_reqs = []
+        indices_to_remove = set()
+
+        if len(self.queue) == 0:
+            return []
+
+        polls = poll_and_all_reduce(
+            [req.disagg_kv_sender for req in self.queue], self.gloo_group
+        )
+
+        for i, (req, poll) in enumerate(zip(self.queue, polls)):
+            if poll == KVPoll.Bootstrapping:
+                continue
+            elif poll == KVPoll.Failed:
+                raise Exception("Bootstrap failed")
+
+            # KV.WaitingForInput - init here
+            num_kv_indices = len(req.origin_input_ids)
+            if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
+                break
+
+            req.metadata_buffer_index = (
+                self.req_to_metadata_buffer_idx_allocator.alloc()
+            )
+            assert req.metadata_buffer_index is not None
+            req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
+
+            bootstrapped_reqs.append(req)
+            indices_to_remove.add(i)
+
+        self.queue = [
+            entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
+        ]
+
+        return bootstrapped_reqs
+
+
+class SchedulerDisaggregationPrefillMixin:
+    """
+    Mixin for Scheduler to handle disaggregated prefill
+    """
+
+    def process_batch_result_disagg_prefill(
+        self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
+    ) -> None:
+        """
+        Transfer KV for prefill-completed requests and add them to disagg_prefill_infight_queue.
+        Adapted from process_batch_result_prefill.
+        """
+
+        next_token_ids = result.next_token_ids.tolist()
+
+        for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
+            req: Req
+            if req.is_chunked <= 0:
+                # There is no output_ids for prefill
+                req.output_ids.append(next_token_id)
+                self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
+                self.send_kv_chunk(req, token_id=next_token_id)
+                self.disagg_prefill_infight_queue.append(req)
+            else:
+                # A request that is still being chunked has not finished its prefill
+                req.is_chunked -= 1
+
+        # TODO: Not sure if this is necessary
+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            # We need to remove this for overlap schedule.
+            self.current_stream.synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
+    def process_disagg_prefill_infight_queue(self: Scheduler) -> None:
+        """
+        Poll the requests in the middle of transfer. If done, return the request.
+        """
+        assert len(self.disagg_prefill_infight_queue) > 0
+
+        done_reqs = []
+
+        polls = poll_and_all_reduce(
+            [req.disagg_kv_sender for req in self.disagg_prefill_infight_queue],
+            self.tp_worker.get_tp_cpu_group(),
+        )
+
+        undone_reqs: List[Req] = []
+        # Check .poll() for the reqs in disagg_prefill_infight_queue. If Success, respond to the client and remove it from the queue
+        for req, poll in zip(self.disagg_prefill_infight_queue, polls):
+            if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
+                undone_reqs.append(req)
+            elif poll == KVPoll.Success:  # transfer done
+                self.tree_cache.cache_finished_req(req)  # unlock the tree
+                req.finished_reason = FINISH_LENGTH(length=0)
+                done_reqs.append(req)
+            elif poll == KVPoll.Failed:
+                raise Exception("Transferring failed")
+
+        # Stream requests which have finished transfer
+        self.stream_output(done_reqs, False, None)
+
+        self.disagg_prefill_infight_queue = undone_reqs
+
+    def process_prefill_chunk(self: Scheduler) -> None:
+        if self.last_batch and self.last_batch.forward_mode.is_extend():
+            if self.chunked_req:
+                # Move the chunked request out of the batch so that we can merge
+                # only finished requests to running_batch.
+                self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
+                self.tree_cache.cache_unfinished_req(self.chunked_req)
+                self.send_kv_chunk(self.chunked_req)
+                # chunked request keeps its rid but will get a new req_pool_idx
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+                self.running_batch.batch_is_full = False
+
+    def send_kv_chunk(
+        self: Scheduler, req: Req, token_id: Optional[int] = None
+    ) -> None:
+        """
+        Send a prefilled chunk to the decode server
+        """
+        start_idx = req.start_send_idx
+        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
+        kv_indices = (
+            self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
+            .cpu()
+            .numpy()
+        )
+        req.start_send_idx = end_idx
+        if token_id is not None:
+            self.disagg_prefill_pending_queue.allocate_token_id(
+                req.metadata_buffer_index, token_id
+            )
+        req.disagg_kv_sender.send(kv_indices)
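
The module docstring above walks through three queues. As a reading aid, here is a schematic of how one iteration of the prefill-side loop could chain the methods defined in this file. The queue and method names come from the diff above, but the loop structure and the get_next_batch_to_run/run_batch calls are assumptions about the surrounding Scheduler, not code from this file.

# Schematic only; not the actual Scheduler event loop.
def prefill_iteration_sketch(scheduler):
    # 1. Bootstrap Queue -> Waiting Queue
    bootstrapped = scheduler.disagg_prefill_pending_queue.pop_bootstrapped()
    scheduler.waiting_queue.extend(bootstrapped)

    # 2. Waiting Queue: run one prefill forward pass
    batch = scheduler.get_next_batch_to_run()
    if batch is not None:
        result = scheduler.run_batch(batch)
        # Finished requests start their KV transfer and join the infight queue
        scheduler.process_batch_result_disagg_prefill(batch, result)

    # 3. Infight Queue: poll transfers; stream out the finished requests
    if scheduler.disagg_prefill_infight_queue:
        scheduler.process_disagg_prefill_infight_queue()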
sglang/srt/disaggregation/utils.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from collections import deque
+from enum import Enum
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+
+class DisaggregationMode(Enum):
+    NULL = "null"
+    PREFILL = "prefill"
+    DECODE = "decode"
+
+
+def poll_and_all_reduce(pollers, gloo_group):
+    polls = [int(poller.poll()) for poller in pollers]
+    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
+    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
+    return tensor_to_reduce.tolist()
+
+
+class ReqToMetadataIdxAllocator:
+    """A memory pool that maps a request to its first output token location."""
+
+    def __init__(
+        self,
+        size: int,
+    ):
+        self.size = size
+        self.free_slots = deque(list(range(size)))
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def alloc(self) -> Optional[int]:
+        if len(self.free_slots) == 0:
+            return None
+
+        return self.free_slots.popleft()
+
+    def free(self, free_index: int):
+        self.free_slots.append(free_index)
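
poll_and_all_reduce min-reduces the integer poll states across ranks, so a request only advances once every rank in the gloo group has reported at least that state. Below is a self-contained, single-process demonstration of the same code path; the poll values are made up, and with world_size=1 the reduce is a no-op, but the call is identical to the multi-rank case.

import os

import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

polls = [2, 3, 1]  # per-request poll states reported by this rank (illustrative)
t = torch.tensor(polls, dtype=torch.uint8, device="cpu")
dist.all_reduce(t, op=dist.ReduceOp.MIN)  # element-wise min across ranks
print(t.tolist())  # [2, 3, 1]; with more ranks, each entry is the slowest rank's state

dist.destroy_process_group()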
sglang/srt/distributed/parallel_state.py
@@ -189,6 +189,9 @@ class GroupCoordinator:
     device_group: ProcessGroup  # group for device communication
     use_pynccl: bool  # a hint of whether to use PyNccl
     use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
+    use_message_queue_broadcaster: (
+        bool  # a hint of whether to use message queue broadcaster
+    )
     # communicators are only created for world size > 1
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
@@ -241,6 +244,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_message_queue_broadcaster = use_message_queue_broadcaster

         # lazy import to avoid documentation build error
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
@@ -269,7 +273,7 @@ class GroupCoordinator:
             HpuCommunicator,
         )

-        self.hpu_communicator: Optional[HpuCommunicator]
+        self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)

@@ -277,7 +281,7 @@ class GroupCoordinator:
             XpuCommunicator,
         )

-        self.xpu_communicator: Optional[XpuCommunicator]
+        self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)

@@ -1312,7 +1316,10 @@ vllm_get_world_group = None


 def monkey_patch_vllm_parallel_state(reverse: bool = False):
-    import vllm.distributed.parallel_state as vllm_parrlel_state
+    try:
+        import vllm.distributed.parallel_state as vllm_parrlel_state
+    except ImportError:
+        return

     global vllm_get_pp_group, vllm_get_tp_group, vllm_get_world_group
     if vllm_get_pp_group is None:
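
The two "Optional[...] = None" changes above fix a latent AttributeError: a bare annotation only declares a type and never creates the attribute, so reading self.hpu_communicator on a coordinator that skipped the "if" branch would fail. A toy reproduction of the pitfall (the class and attribute names here are illustrative, not from the diff):

from typing import Optional

class Coordinator:
    def __init__(self, enable: bool):
        self.comm: Optional[object]  # annotation only; no attribute is created
        if enable:
            self.comm = object()

c = Coordinator(enable=False)
print(hasattr(c, "comm"))  # False: accessing c.comm raises AttributeError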