sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff shows the content changes between two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in those registries.
Files changed (99)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -8
  3. sglang/compile_deep_gemm.py +177 -0
  4. sglang/lang/backend/openai.py +5 -1
  5. sglang/lang/backend/runtime_endpoint.py +5 -1
  6. sglang/srt/code_completion_parser.py +1 -1
  7. sglang/srt/configs/deepseekvl2.py +1 -1
  8. sglang/srt/configs/model_config.py +11 -2
  9. sglang/srt/constrained/llguidance_backend.py +78 -61
  10. sglang/srt/constrained/xgrammar_backend.py +1 -0
  11. sglang/srt/conversation.py +34 -1
  12. sglang/srt/disaggregation/decode.py +96 -5
  13. sglang/srt/disaggregation/mini_lb.py +113 -15
  14. sglang/srt/disaggregation/mooncake/conn.py +199 -32
  15. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  16. sglang/srt/disaggregation/nixl/conn.py +622 -0
  17. sglang/srt/disaggregation/prefill.py +119 -20
  18. sglang/srt/disaggregation/utils.py +17 -0
  19. sglang/srt/entrypoints/engine.py +4 -0
  20. sglang/srt/entrypoints/http_server.py +11 -9
  21. sglang/srt/function_call_parser.py +132 -0
  22. sglang/srt/layers/activation.py +2 -2
  23. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +809 -160
  25. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  26. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  28. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  29. sglang/srt/layers/attention/vision.py +2 -0
  30. sglang/srt/layers/dp_attention.py +1 -1
  31. sglang/srt/layers/layernorm.py +42 -5
  32. sglang/srt/layers/logits_processor.py +2 -2
  33. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  34. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  38. sglang/srt/layers/pooler.py +6 -0
  39. sglang/srt/layers/quantization/awq.py +5 -1
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  41. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  42. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  43. sglang/srt/layers/quantization/deep_gemm.py +385 -0
  44. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/quantization/gptq.py +13 -7
  47. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  48. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  49. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +176 -132
  52. sglang/srt/layers/sampler.py +2 -2
  53. sglang/srt/managers/data_parallel_controller.py +17 -4
  54. sglang/srt/managers/io_struct.py +21 -3
  55. sglang/srt/managers/mm_utils.py +85 -28
  56. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  57. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  58. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  59. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  60. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  61. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  62. sglang/srt/managers/schedule_batch.py +42 -12
  63. sglang/srt/managers/scheduler.py +47 -26
  64. sglang/srt/managers/tokenizer_manager.py +120 -30
  65. sglang/srt/managers/tp_worker.py +1 -0
  66. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  67. sglang/srt/mem_cache/memory_pool.py +118 -13
  68. sglang/srt/model_executor/cuda_graph_runner.py +16 -10
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +29 -27
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +153 -76
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpm3.py +2 -2
  78. sglang/srt/models/minicpmo.py +22 -7
  79. sglang/srt/models/mllama4.py +2 -2
  80. sglang/srt/models/qwen2_5_vl.py +3 -6
  81. sglang/srt/models/qwen2_vl.py +3 -7
  82. sglang/srt/models/roberta.py +178 -0
  83. sglang/srt/openai_api/adapter.py +87 -10
  84. sglang/srt/openai_api/protocol.py +6 -1
  85. sglang/srt/server_args.py +65 -60
  86. sglang/srt/speculative/build_eagle_tree.py +2 -2
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +2 -2
  89. sglang/srt/speculative/eagle_worker.py +2 -7
  90. sglang/srt/torch_memory_saver_adapter.py +10 -1
  91. sglang/srt/utils.py +48 -6
  92. sglang/test/runners.py +6 -13
  93. sglang/test/test_utils.py +39 -19
  94. sglang/version.py +1 -1
  95. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
  96. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
  97. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  98. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode.py
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
 from __future__ import annotations
 
 import logging
+from collections import deque
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
@@ -136,7 +137,7 @@ class DecodePreallocQueue:
         kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
         kv_receiver = kv_receiver_class(
             mgr=self.kv_manager,
-            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
+            bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
         )
         self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
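Note: together with the mini_lb.py changes further down, this means the decode worker now dials the bootstrap address advertised by each request instead of a single port configured on the queue, so different prefill servers can expose different bootstrap ports. A minimal illustration of the per-request address (illustrative stand-in type, not sglang's Req class):

    # Illustrative sketch only: the decode side builds the bootstrap address
    # from fields carried by the request itself (host, port, room).
    from dataclasses import dataclass

    @dataclass
    class FakeReq:  # hypothetical stand-in for the fields a routed request carries
        bootstrap_host: str
        bootstrap_port: int
        bootstrap_room: int

    req = FakeReq(bootstrap_host="10.0.0.5", bootstrap_port=8998, bootstrap_room=1234)
    bootstrap_addr = f"{req.bootstrap_host}:{req.bootstrap_port}"
    print(bootstrap_addr)  # -> 10.0.0.5:8998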
@@ -306,7 +307,7 @@ class DecodeTransferQueue:
     def extend(self, req_conns) -> None:
         self.queue.extend(req_conns)
 
-    def pop_transferred(self) -> List[Req]:
+    def pop_transferred(self) -> List[DecodeRequest]:
         if not self.queue:
             return []
 
@@ -329,7 +330,7 @@ class DecodeTransferQueue:
                 assert len(decode_req.req.output_ids) == 0
                 assert decode_req.req.transferred_output_id is None
                 decode_req.req.transferred_output_id = output_id
-                transferred_reqs.append(decode_req.req)
+                transferred_reqs.append(decode_req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
@@ -443,8 +444,17 @@ class ScheduleBatchDisaggregationDecodeMixin:
 
 class SchedulerDisaggregationDecodeMixin:
 
+    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
+        batch, _ = self.prepare_dp_attn_batch(batch)
+        result = None
+        if batch:
+            result = self.run_batch(batch)
+            if not delay_process:
+                self.process_batch_result(batch, result)
+        return batch, result
+
     @torch.no_grad()
-    def event_loop_normal_disagg_decode(self):
+    def event_loop_normal_disagg_decode(self: Scheduler):
        """A normal scheduler loop for decode worker in disaggregation mode."""
 
         while True:
@@ -455,14 +465,25 @@ class SchedulerDisaggregationDecodeMixin:
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch
 
+            prepare_dp_attn_flag = (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            )
+
             if batch:
                 # Generate fake extend output.
                 if batch.forward_mode.is_extend():
                     # Note: Logprobs should be handled on the prefill engine.
                     self.stream_output(batch.reqs, False)
+                    if prepare_dp_attn_flag:
+                        self._prepare_idle_batch_and_run(None)
                 else:
+                    if prepare_dp_attn_flag:
+                        self.prepare_dp_attn_batch(batch)
                     result = self.run_batch(batch)
                     self.process_batch_result(batch, result)
+            elif prepare_dp_attn_flag:
+                batch, _ = self._prepare_idle_batch_and_run(None)
 
             if batch is None and (
                 len(self.disagg_decode_transfer_queue.queue)
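The prepare_dp_attn_flag plumbing reflects that, with data-parallel attention or sequence-parallel layernorm enabled, every rank apparently has to take a forward step each iteration even when it has no runnable batch, otherwise the collectives inside the model would desynchronize; `_prepare_idle_batch_and_run(None)` gives such ranks an idle batch to execute. A toy model of that pattern under those assumed semantics (hypothetical classes, not the real Scheduler API):

    # Toy model of the "idle batch" rule: with data-parallel attention, every rank
    # must execute a forward step each iteration so the collectives stay in sync.
    class ToyRank:
        def __init__(self, name):
            self.name = name
            self.steps = 0

        def run_batch(self, batch):
            self.steps += 1  # stands in for a forward pass containing collectives
            return f"{self.name} ran {batch}"

    def step(rank, batch):
        # A rank with no work still runs an "idle" batch instead of skipping the step.
        batch = batch if batch is not None else "idle-batch"
        return rank.run_batch(batch)

    r0, r1 = ToyRank("dp0"), ToyRank("dp1")
    print(step(r0, "real-batch"), "|", step(r1, None))
    assert r0.steps == r1.steps == 1  # both ranks advanced together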
@@ -475,6 +496,69 @@ class SchedulerDisaggregationDecodeMixin:
 
             self.last_batch = batch
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_decode(self: Scheduler):
+        result_queue = deque()
+        self.last_batch: Optional[ScheduleBatch] = None
+        self.last_batch_in_queue = False  # last batch is modifed in-place, so we need another variable to track if it's extend
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+            last_batch_in_queue = False
+
+            prepare_dp_attn_flag = (
+                self.server_args.enable_dp_attention
+                or self.server_args.enable_sp_layernorm
+            )
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(batch.reqs, False)
+                    if prepare_dp_attn_flag:
+                        batch_, result = self._prepare_idle_batch_and_run(
+                            None, delay_process=True
+                        )
+                        if batch_:
+                            result_queue.append((batch_.copy(), result))
+                            last_batch_in_queue = True
+                else:
+                    if prepare_dp_attn_flag:
+                        self.prepare_dp_attn_batch(batch)
+                    result = self.run_batch(batch)
+                    result_queue.append((batch.copy(), result))
+                    last_batch_in_queue = True
+            elif prepare_dp_attn_flag:
+                batch, result = self._prepare_idle_batch_and_run(
+                    None, delay_process=True
+                )
+                if batch:
+                    result_queue.append((batch.copy(), result))
+                    last_batch_in_queue = True
+
+            # Process the results of the previous batch but skip if the last batch is extend
+            if self.last_batch and self.last_batch_in_queue:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do self-check and re-init some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            self.last_batch_in_queue = last_batch_in_queue
+
     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
     ) -> Optional[Tuple[ScheduleBatch, bool]]:
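The new overlap loop defers CPU-side result handling by one iteration: each launched batch is pushed onto `result_queue` together with a copy of its metadata, and only on the next pass, after the following batch has been kicked off, is the previous result processed. A stripped-down, self-contained version of that deferral pattern (hypothetical `launch`/`finish` callables, not the sglang scheduler):

    from collections import deque

    def overlap_loop(batches, launch, finish):
        """Toy overlap pattern: finish work item i only after item i+1 is launched."""
        pending = deque()
        for batch in batches:
            pending.append((batch, launch(batch)))   # kick off the current batch
            if len(pending) > 1:                      # the previous batch can now be finished
                finish(*pending.popleft())
        while pending:                                # drain whatever is still in flight
            finish(*pending.popleft())

    log = []
    overlap_loop(
        ["b0", "b1", "b2"],
        launch=lambda b: log.append(f"launch {b}") or f"result({b})",
        finish=lambda b, r: log.append(f"finish {b} -> {r}"),
    )
    print(log)  # launches of b1/b2 appear before the finishes of b0/b1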
@@ -557,8 +641,15 @@ class SchedulerDisaggregationDecodeMixin:
 
     def process_decode_queue(self: Scheduler):
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
+
+        def _num_pre_alloc(req):
+            return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
+
+        self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
-        self.waiting_queue.extend(alloc_reqs)
+        self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
+
+        self.waiting_queue.extend([req.req for req in alloc_reqs])
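The new num_tokens_pre_allocated counter tracks how many KV-cache slots are reserved for requests sitting between pre-allocation and the waiting queue; each request accounts for its full prompt plus any already-generated output tokens minus one (the reading of the -1 is an assumption, but the arithmetic is exactly as written above). A quick worked example:

    # Worked example of the per-request pre-allocation count used above.
    def num_pre_alloc(origin_input_ids, output_ids):
        return len(origin_input_ids) + max(len(output_ids) - 1, 0)

    print(num_pre_alloc(list(range(100)), []))             # fresh request: 100
    print(num_pre_alloc(list(range(100)), [7, 8, 9, 10]))  # 4 outputs so far: 100 + 3 = 103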
sglang/srt/disaggregation/mini_lb.py
@@ -6,6 +6,7 @@ import asyncio
 import random
 import urllib
 from itertools import chain
+from typing import List
 
 import aiohttp
 import orjson
@@ -14,17 +15,27 @@ from fastapi import FastAPI, HTTPException
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
 
+class PrefillConfig:
+    def __init__(self, url: str, bootstrap_port: int):
+        self.url = url
+        self.bootstrap_port = bootstrap_port
+
+
 class MiniLoadBalancer:
-    def __init__(self, prefill_servers, decode_servers):
-        self.prefill_servers = prefill_servers
+    def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
+        self.prefill_configs = prefill_configs
+        self.prefill_servers = [p.url for p in prefill_configs]
         self.decode_servers = decode_servers
 
     def select_pair(self):
-        return random.choice(self.prefill_servers), random.choice(self.decode_servers)
+        prefill_config = random.choice(self.prefill_configs)
+        decode_server = random.choice(self.decode_servers)
+        return prefill_config.url, prefill_config.bootstrap_port, decode_server
 
     async def generate(
-        self, modified_request, prefill_server, decode_server
+        self, modified_request, prefill_server, decode_server, endpoint
     ) -> ORJSONResponse:
+        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
 
         async with aiohttp.ClientSession(
             timeout=aiohttp.ClientTimeout(
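select_pair now returns a triple so callers also receive the bootstrap port that belongs to the chosen prefill server. A minimal standalone sketch of that pairing (mirrors the classes above but does not import sglang):

    import random
    from typing import List, Tuple

    class PrefillConfig:  # same shape as the class added above
        def __init__(self, url: str, bootstrap_port: int):
            self.url = url
            self.bootstrap_port = bootstrap_port

    def select_pair(
        prefills: List[PrefillConfig], decodes: List[str]
    ) -> Tuple[str, int, str]:
        p = random.choice(prefills)  # pick a prefill server together with its own port
        return p.url, p.bootstrap_port, random.choice(decodes)

    print(select_pair([PrefillConfig("http://prefill-0:30000", 8998)],
                      ["http://decode-0:30001"]))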
@@ -32,8 +43,8 @@ class MiniLoadBalancer:
             )  # Add timeout for request reliability
         ) as session:
             tasks = [
-                session.post(f"{prefill_server}/generate", json=modified_request),
-                session.post(f"{decode_server}/generate", json=modified_request),
+                session.post(f"{prefill_server}/{endpoint}", json=modified_request),
+                session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
             # Wait for both responses to complete. Prefill should end first.
             prefill_response, decode_response = await asyncio.gather(*tasks)
@@ -43,7 +54,11 @@ class MiniLoadBalancer:
                 status_code=decode_response.status,
             )
 
-    async def generate_stream(self, modified_request, prefill_server, decode_server):
+    async def generate_stream(
+        self, modified_request, prefill_server, decode_server, endpoint="generate"
+    ):
+        assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
+
         async def stream_results():
             async with aiohttp.ClientSession(
                 timeout=aiohttp.ClientTimeout(
@@ -54,10 +69,10 @@
                 # Create the tasks for both prefill and decode requests
                 tasks = [
                     session.post(
-                        f"{prefill_server}/generate", json=modified_request
+                        f"{prefill_server}/{endpoint}", json=modified_request
                     ),
                     session.post(
-                        f"{decode_server}/generate", json=modified_request
+                        f"{decode_server}/{endpoint}", json=modified_request
                     ),
                 ]
                 # Wait for both responses to complete. Since this is streaming, they return immediately.
@@ -155,7 +170,46 @@ async def get_model_info():
 
 @app.post("/generate")
 async def handle_generate_request(request_data: dict):
-    prefill_server, decode_server = load_balancer.select_pair()
+    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
+
+    # Parse and transform prefill_server for bootstrap data
+    parsed_url = urllib.parse.urlparse(prefill_server)
+    hostname = parsed_url.hostname
+    modified_request = request_data.copy()
+
+    batch_size = _get_request_batch_size(modified_request)
+    if batch_size is not None:
+        modified_request.update(
+            {
+                "bootstrap_host": [hostname] * batch_size,
+                "bootstrap_port": [bootstrap_port] * batch_size,
+                "bootstrap_room": [
+                    _generate_bootstrap_room() for _ in range(batch_size)
+                ],
+            }
+        )
+    else:
+        modified_request.update(
+            {
+                "bootstrap_host": hostname,
+                "bootstrap_port": bootstrap_port,
+                "bootstrap_room": _generate_bootstrap_room(),
+            }
+        )
+
+    if request_data.get("stream", False):
+        return await load_balancer.generate_stream(
+            modified_request, prefill_server, decode_server, "generate"
+        )
+    else:
+        return await load_balancer.generate(
+            modified_request, prefill_server, decode_server, "generate"
+        )
+
+
+@app.post("/v1/chat/completions")
+async def handle_completion_request(request_data: dict):
+    prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
 
     # Parse and transform prefill_server for bootstrap data
     parsed_url = urllib.parse.urlparse(prefill_server)
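For batched /generate requests the balancer now attaches one bootstrap tuple per prompt, so each item in the batch gets its own transfer "room". Roughly, a request whose "text" is a list of length 2 is rewritten with list-valued bootstrap fields of the same length. A small dict-level illustration of that rewrite (no sglang imports; the helper name is local to this sketch):

    import random

    def attach_bootstrap(request: dict, hostname: str, bootstrap_port: int) -> dict:
        """Sketch of how the balancer decorates a /generate payload."""
        out = request.copy()
        text = request.get("text")
        batch_size = len(text) if isinstance(text, list) else None
        if batch_size is None:  # single prompt: scalar fields
            out.update(bootstrap_host=hostname,
                       bootstrap_port=bootstrap_port,
                       bootstrap_room=random.randint(0, 2**63 - 1))
        else:                   # batched prompt: one entry per item
            out.update(bootstrap_host=[hostname] * batch_size,
                       bootstrap_port=[bootstrap_port] * batch_size,
                       bootstrap_room=[random.randint(0, 2**63 - 1)
                                       for _ in range(batch_size)])
        return out

    print(attach_bootstrap({"text": ["hi", "bye"]}, "prefill-0", 8998))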
@@ -164,20 +218,40 @@ async def handle_generate_request(request_data: dict):
     modified_request.update(
         {
             "bootstrap_host": hostname,
+            "bootstrap_port": bootstrap_port,
             "bootstrap_room": random.randint(0, 2**63 - 1),
         }
     )
 
     if request_data.get("stream", False):
         return await load_balancer.generate_stream(
-            modified_request, prefill_server, decode_server
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
         )
     else:
         return await load_balancer.generate(
-            modified_request, prefill_server, decode_server
+            modified_request,
+            prefill_server,
+            decode_server,
+            endpoint="v1/chat/completions",
         )
 
 
+def _generate_bootstrap_room():
+    return random.randint(0, 2**63 - 1)
+
+
+# We may utilize `GenerateReqInput`'s logic later
+def _get_request_batch_size(request):
+    if (text := request.get("text")) is not None:
+        return None if isinstance(text, str) else len(text)
+    if (input_ids := request.get("input_ids")) is not None:
+        return None if isinstance(input_ids[0], int) else len(input_ids)
+    return None
+
+
 @app.get("/v1/models")
 async def get_models():
     prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server
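With the route added above, OpenAI-style chat requests can be sent to the load balancer itself, which fans them out to a prefill/decode pair. One possible client call, assuming the balancer listens on localhost:8000 (the default port in the __main__ block) and using a placeholder model name:

    import requests  # assumes the `requests` package is installed

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",  # route handled by the balancer
        json={
            "model": "your-model-name",               # placeholder; depends on the deployment
            "messages": [{"role": "user", "content": "Hello!"}],
            "stream": False,
        },
        timeout=600,
    )
    print(resp.status_code, resp.json())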
@@ -194,9 +268,9 @@ async def get_models():
         raise HTTPException(status_code=500, detail=str(e))
 
 
-def run(prefill_addrs, decode_addrs, host, port):
+def run(prefill_configs, decode_addrs, host, port):
     global load_balancer
-    load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
+    load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
     uvicorn.run(app, host=host, port=port)
 
 
@@ -207,6 +281,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--prefill", required=True, help="Comma-separated URLs for prefill servers"
     )
+    parser.add_argument(
+        "--prefill-bootstrap-ports",
+        help="Comma-separated bootstrap ports for prefill servers",
+        default="8998",
+    )
     parser.add_argument(
         "--decode", required=True, help="Comma-separated URLs for decode servers"
     )
@@ -217,4 +296,23 @@
         "--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
     )
     args = parser.parse_args()
-    run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
+
+    prefill_urls = args.prefill.split(",")
+    bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]
+
+    if len(bootstrap_ports) == 1:
+        bootstrap_ports = bootstrap_ports * len(prefill_urls)
+    else:
+        if len(bootstrap_ports) != len(prefill_urls):
+            raise ValueError(
+                "Number of prefill URLs must match number of bootstrap ports"
+            )
+            exit(1)
+
+    prefill_configs = []
+    for url, port in zip(prefill_urls, bootstrap_ports):
+        prefill_configs.append(PrefillConfig(url, port))
+
+    decode_addrs = args.decode.split(",")
+
+    run(prefill_configs, decode_addrs, args.host, args.port)
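The __main__ block accepts either a single shared bootstrap port or one port per prefill URL; any other combination is rejected. A worked example of that normalization, mirroring the logic above outside of argparse:

    # Worked example of the bootstrap-port normalization in the __main__ block.
    prefill_urls = "http://prefill-0:30000,http://prefill-1:30000".split(",")
    bootstrap_ports = [int(p) for p in "8998".split(",")]  # --prefill-bootstrap-ports 8998

    if len(bootstrap_ports) == 1:
        bootstrap_ports = bootstrap_ports * len(prefill_urls)  # share the single port
    elif len(bootstrap_ports) != len(prefill_urls):
        raise ValueError("Number of prefill URLs must match number of bootstrap ports")

    print(list(zip(prefill_urls, bootstrap_ports)))
    # [('http://prefill-0:30000', 8998), ('http://prefill-1:30000', 8998)]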