sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/nixl/conn.py:

@@ -31,23 +31,19 @@ from sglang.srt.utils import get_local_ip_by_remote
 
 logger = logging.getLogger(__name__)
 
-NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
-
 GUARD = "NixlMsgGuard".encode("ascii")
 
 
 @dataclasses.dataclass
 class TransferInfo:
+    """Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""
+
     room: int
     endpoint: str
     dst_port: int
-    agent_metadata: bytes
     agent_name: str
-    dst_kv_ptrs: list[int]
     dst_kv_indices: npt.NDArray[np.int32]
-    dst_aux_ptrs: list[int]
     dst_aux_index: int
-    dst_gpu_id: int
     required_dst_info_num: int
 
     def is_dummy(self):
@@ -59,14 +55,37 @@ class TransferInfo:
             room=int(msg[0].decode("ascii")),
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
-            agent_metadata=msg[3],
-            agent_name=msg[4].decode("ascii"),
+            agent_name=msg[3].decode("ascii"),
+            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
+            dst_aux_index=int(msg[5].decode("ascii")),
+            required_dst_info_num=int(msg[6].decode("ascii")),
+        )
+
+
+@dataclasses.dataclass
+class KVArgsRegisterInfo:
+    """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
+
+    room: str
+    endpoint: str
+    dst_port: int
+    agent_name: str
+    agent_metadata: bytes
+    dst_kv_ptrs: list[int]
+    dst_aux_ptrs: list[int]
+    gpu_id: int
+
+    @classmethod
+    def from_zmq(cls, msg: List[bytes]):
+        return cls(
+            room=str(msg[0].decode("ascii")),
+            endpoint=msg[1].decode("ascii"),
+            dst_port=int(msg[2].decode("ascii")),
+            agent_name=msg[3].decode("ascii"),
+            agent_metadata=msg[4],
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
-            dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
-            dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
-            dst_aux_index=int(msg[8].decode("ascii")),
-            dst_gpu_id=int(msg[9].decode("ascii")),
-            required_dst_info_num=int(msg[10].decode("ascii")),
+            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
+            gpu_id=int(msg[7].decode("ascii")),
         )
 
 
@@ -109,9 +128,9 @@ class NixlKVManager(CommonKVManager):
         self.register_buffer_to_engine()
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.request_status = {}
-            self.transfer_infos: Dict[int, TransferInfo] = {}
-            self.peer_names: Dict[str, str] = {}
+            self.request_status: Dict[int, KVPoll] = {}
+            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
+            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self._start_bootstrap_thread()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
@@ -154,10 +173,13 @@ class NixlKVManager(CommonKVManager):
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
 
-    def _add_remote(self, agent_name: str, agent_metadata: bytes):
-        if agent_name not in self.peer_names:
-            self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata)
-        return self.peer_names[agent_name]
+    def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
+        agent_name = decode_kv_args.agent_name
+        if agent_name in self.decode_kv_args_table:
+            logger.info(f"Peer {agent_name} was already registered, ignoring.")
+            return
+        self.decode_kv_args_table[agent_name] = decode_kv_args
+        self.agent.add_remote_agent(decode_kv_args.agent_metadata)
 
     def send_kvcache(
         self,
@@ -262,17 +284,17 @@ class NixlKVManager(CommonKVManager):
             if req.is_dummy():
                 continue
 
-            peer_name = self._add_remote(req.agent_name, req.agent_metadata)
             chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
             assert len(chunked_dst_kv_indice) == len(kv_indices)
+            assert req.agent_name in self.decode_kv_args_table
 
             notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
             kv_xfer_handle = self.send_kvcache(
-                peer_name,
+                req.agent_name,
                 kv_indices,
-                req.dst_kv_ptrs,
+                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
                 chunked_dst_kv_indice,
-                req.dst_gpu_id,
+                self.decode_kv_args_table[req.agent_name].gpu_id,
                 notif,
             )
             handles.append(kv_xfer_handle)
@@ -280,13 +302,15 @@ class NixlKVManager(CommonKVManager):
             if is_last:
                 assert aux_index is not None
                 aux_xfer_handle = self.send_aux(
-                    peer_name,
+                    req.agent_name,
                     aux_index,
-                    req.dst_aux_ptrs,
+                    self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
                     req.dst_aux_index,
                     str(req.room) + "_aux",
                 )
                 handles.append(aux_xfer_handle)
+        if is_last:
+            del self.transfer_infos[bootstrap_room]
         return handles
 
     def update_transfer_status(self):
@@ -328,16 +352,23 @@ class NixlKVManager(CommonKVManager):
                 ), f"First message should be {GUARD}. Foreign traffic?"
                 waiting_req_bytes = waiting_req_bytes[1:]
                 room = waiting_req_bytes[0].decode("ascii")
-
-                required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
+                agent_name = waiting_req_bytes[3].decode("ascii")
+                if room == "None":
+                    # Register new peer and save KV base pointers.
+                    self._add_remote_peer(
+                        KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
+                    )
+                    logger.debug(f"Register KVArgs from {agent_name} successfully")
+                    continue
                 room = int(room)
-                agent_name = waiting_req_bytes[4].decode("ascii")
                 if room not in self.transfer_infos:
                     self.transfer_infos[room] = {}
                 self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
                     waiting_req_bytes
                 )
-
+                required_dst_info_num = self.transfer_infos[room][
+                    agent_name
+                ].required_dst_info_num
                 logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
                 if len(self.transfer_infos[room]) == required_dst_info_num:
                     logger.debug(f"{room=} is bootstrapped")
@@ -391,6 +422,7 @@ class NixlKVSender(BaseKVSender):
         self.chunk_id += 1
         if is_last:
             self.has_sent = True
+            del self.kv_mgr.request_status[self.bootstrap_room]
 
     def poll(self) -> KVPoll:
         if not self.has_sent:
@@ -415,6 +447,7 @@ class NixlKVReceiver(CommonKVReceiver):
         data_parallel_rank: Optional[int] = None,
     ):
         self.started_transfer = False
+        self.conclude_state = None
        super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
 
     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
@@ -426,17 +459,8 @@ class NixlKVReceiver(CommonKVReceiver):
                 f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
             )
             is_dummy = bootstrap_info["is_dummy"]
-
-            # TODO: send_kv_args earlier
-            packed_kv_data_ptrs = b"".join(
-                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-            )
-            packed_aux_data_ptrs = b"".join(
-                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-            )
-
             logger.debug(
-                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
+                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
             )
             sock, lock = self._connect("tcp://" + self.prefill_server_url)
             with lock:
@@ -446,13 +470,9 @@ class NixlKVReceiver(CommonKVReceiver):
                         str(self.bootstrap_room).encode("ascii"),
                         get_local_ip_by_remote().encode("ascii"),
                         str(self.kv_mgr.rank_port).encode("ascii"),
-                        self.kv_mgr.agent.get_agent_metadata(),
                         self.kv_mgr.agent.name.encode("ascii"),
-                        packed_kv_data_ptrs,
                         kv_indices.tobytes() if not is_dummy else b"",
-                        packed_aux_data_ptrs,
                         str(aux_index).encode("ascii"),
-                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
                        str(self.required_dst_info_num).encode("ascii"),
                     ]
                 )
@@ -460,17 +480,45 @@ class NixlKVReceiver(CommonKVReceiver):
         self.started_transfer = True
 
     def poll(self) -> KVPoll:
+        if self.conclude_state is not None:
+            return self.conclude_state
         if not self.started_transfer:
             return KVPoll.WaitingForInput  # type: ignore
 
         self.kv_mgr.update_transfer_status()
-
         if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
+            self.conclude_state = KVPoll.Success
+            del self.kv_mgr.transfer_statuses[self.bootstrap_room]
             return KVPoll.Success  # type: ignore
         return KVPoll.WaitingForInput  # type: ignore
 
     def _register_kv_args(self):
-        pass
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
+            )
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+            )
+
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        GUARD,
+                        "None".encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.kv_mgr.agent.name.encode("ascii"),
+                        self.kv_mgr.agent.get_agent_metadata(),
+                        packed_kv_data_ptrs,
+                        packed_aux_data_ptrs,
+                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                    ]
+                )
 
     def failure_exception(self):
         raise Exception("Fake KVReceiver Exception")
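These conn.py hunks split the old single bootstrap message into a two-phase handshake: each decode-side KVReceiver now sends a one-time KVArgsRegisterInfo (its room field is the literal "None") carrying the NIXL agent metadata, KV/aux base pointers, and GPU id, and afterwards sends a per-request TransferInfo that carries only destination indices. A minimal sketch of the two multipart payloads in the field order the diff uses; the helper functions and their arguments are illustrative, not part of the package:

import struct

import numpy as np

GUARD = "NixlMsgGuard".encode("ascii")


def register_msg(ip, port, agent_name, agent_metadata, kv_ptrs, aux_ptrs, gpu_id):
    # Sent once per prefill peer; room == "None" routes the message to
    # KVArgsRegisterInfo.from_zmq() in the prefill bootstrap thread.
    return [
        GUARD,
        b"None",
        ip.encode("ascii"),
        str(port).encode("ascii"),
        agent_name.encode("ascii"),
        agent_metadata,
        b"".join(struct.pack("Q", p) for p in kv_ptrs),   # dst_kv_ptrs
        b"".join(struct.pack("Q", p) for p in aux_ptrs),  # dst_aux_ptrs
        str(gpu_id).encode("ascii"),
    ]


def transfer_msg(ip, port, room, agent_name, kv_indices, aux_index, required_dst_info_num):
    # Sent per request; base pointers are no longer repeated, only indices travel.
    return [
        GUARD,
        str(room).encode("ascii"),
        ip.encode("ascii"),
        str(port).encode("ascii"),
        agent_name.encode("ascii"),
        np.asarray(kv_indices, dtype=np.int32).tobytes(),
        str(aux_index).encode("ascii"),
        str(required_dst_info_num).encode("ascii"),
    ]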
sglang/srt/disaggregation/prefill.py:

@@ -25,7 +25,6 @@ from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional
 
-import numpy as np
 import torch
 
 from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
@@ -45,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import require_mlp_sync
 
 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -93,8 +93,6 @@ class PrefillBootstrapQueue:
         self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
         self.queue: List[Req] = []
-        self.pp_rank = pp_rank
-        self.pp_size = pp_size
         self.gloo_group = gloo_group
         self.max_total_num_tokens = max_total_num_tokens
         self.scheduler = scheduler
@@ -124,6 +122,9 @@
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
+        if not self.is_mla_backend:
+            kv_args.kv_head_num = self.token_to_kv_pool.head_num
+        kv_args.page_size = self.token_to_kv_pool.page_size
 
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
@@ -274,12 +275,8 @@ class SchedulerDisaggregationPrefillMixin:
         self.process_prefill_chunk()
         batch = self.get_new_batch_prefill()
 
-        # Handle DP attention
-        if (
-            self.server_args.enable_dp_attention
-            or self.server_args.enable_sp_layernorm
-        ):
-            batch, _ = self.prepare_dp_attn_batch(batch)
+        if require_mlp_sync(self.server_args):
+            batch, _ = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
 
         if batch:
@@ -312,12 +309,8 @@
         self.process_prefill_chunk()
         batch = self.get_new_batch_prefill()
 
-        # Handle DP attention
-        if (
-            self.server_args.enable_dp_attention
-            or self.server_args.enable_sp_layernorm
-        ):
-            batch, _ = self.prepare_dp_attn_batch(batch)
+        if require_mlp_sync(self.server_args):
+            batch, _ = self.prepare_mlp_sync_batch(batch)
         self.cur_batch = batch
         if batch:
             result = self.run_batch(batch)
@@ -393,6 +386,8 @@
             logits_output.input_token_logprobs = tuple(
                 logits_output.input_token_logprobs.tolist()
             )
+
+        hidden_state_offset = 0
         for i, (req, next_token_id) in enumerate(
             zip(batch.reqs, next_token_ids, strict=True)
         ):
@@ -402,6 +397,16 @@
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                 self.disagg_prefill_inflight_queue.append(req)
+                if logits_output.hidden_states is not None:
+                    last_hidden_index = (
+                        hidden_state_offset + extend_input_len_per_req[i] - 1
+                    )
+                    req.hidden_states_tensor = (
+                        logits_output.hidden_states[last_hidden_index].cpu().clone()
+                    )
+                    hidden_state_offset += extend_input_len_per_req[i]
+                else:
+                    req.hidden_states_tensor = None
                 if req.return_logprob:
                     assert extend_logprob_start_len_per_req is not None
                     assert extend_input_len_per_req is not None
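The last prefill.py hunk keeps the final-position hidden state of each request for PD-disaggregated speculative decoding: logits_output.hidden_states is flattened across the extend batch, so the scheduler walks it with a running offset and saves only the row at offset + extend_len - 1 per request. A toy illustration of that indexing, with made-up lengths:

import torch

# Suppose a prefill batch of three requests with extend lengths 4, 2, and 3 tokens,
# so hidden_states is flattened to shape [4 + 2 + 3, hidden_size].
extend_input_len_per_req = [4, 2, 3]
hidden_states = torch.randn(sum(extend_input_len_per_req), 8)

last_per_req = []
offset = 0
for length in extend_input_len_per_req:
    # Same arithmetic as the diff: the last token of request i sits at offset + length - 1.
    last_per_req.append(hidden_states[offset + length - 1].clone())
    offset += length

assert [t.shape for t in last_per_req] == [torch.Size([8])] * 3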
sglang/srt/disaggregation/utils.py:

@@ -6,6 +6,7 @@ import random
 import threading
 import warnings
 from collections import deque
+from contextlib import nullcontext
 from enum import Enum
 from typing import TYPE_CHECKING, List, Optional
 
@@ -84,24 +85,43 @@ class ReqToMetadataIdxAllocator:
 
 
 class MetadataBuffers:
-    def __init__(self, size: int, max_top_logprobs_num: int = 128):
-        # TODO: abort top_logprobs_num > 128 in PD
-
-        # We transfer the metadata of first output token to decode
-        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
-        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
-        self.output_token_logprobs_val = torch.zeros(
-            (size, 16), dtype=torch.float32, device="cpu"
-        )
-        self.output_token_logprobs_idx = torch.zeros(
-            (size, 16), dtype=torch.int32, device="cpu"
-        )
-        self.output_top_logprobs_val = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
-        )
-        self.output_top_logprobs_idx = torch.zeros(
-            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
-        )
+    def __init__(
+        self,
+        size: int,
+        hidden_size: int,
+        dtype: torch.dtype,
+        max_top_logprobs_num: int = 128,
+        custom_mem_pool: torch.cuda.MemPool = None,
+    ):
+        self.custom_mem_pool = custom_mem_pool
+        device = "cuda" if self.custom_mem_pool else "cpu"
+
+        with (
+            torch.cuda.use_mem_pool(self.custom_mem_pool)
+            if self.custom_mem_pool
+            else nullcontext()
+        ):
+            # TODO: abort top_logprobs_num > 128 in PD
+
+            # We transfer the metadata of first output token to decode
+            # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+            self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
+
+            self.output_token_logprobs_val = torch.zeros(
+                (size, 16), dtype=torch.float32, device=device
+            )
+            self.output_token_logprobs_idx = torch.zeros(
+                (size, 16), dtype=torch.int32, device=device
+            )
+            self.output_top_logprobs_val = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.float32, device=device
+            )
+            self.output_top_logprobs_idx = torch.zeros(
+                (size, max_top_logprobs_num), dtype=torch.int32, device=device
+            )
+            self.output_hidden_states = torch.zeros(
+                (size, hidden_size), dtype=dtype, device=device
+            )
 
     def get_buf_infos(self):
         ptrs = [
@@ -110,6 +130,7 @@ class MetadataBuffers:
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
             self.output_ids.nbytes,
@@ -117,6 +138,7 @@
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_hidden_states.nbytes,
         ]
         item_lens = [
             self.output_ids[0].nbytes,
@@ -124,6 +146,7 @@
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens
 
@@ -134,6 +157,7 @@
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
+            self.output_hidden_states[idx],
         )
 
     def set_buf(self, req: Req):
@@ -161,6 +185,11 @@
             ] = torch.tensor(
                 req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
             )
+        # for PD + spec decode
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )
 
 
 #########################
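After these utils.py changes, MetadataBuffers must be told the model hidden size and dtype so it can stage one hidden-state row per slot, and it can optionally allocate everything on the GPU inside torch.cuda.use_mem_pool when a custom memory pool is supplied; without a pool the buffers stay on the CPU as before. A small construction sketch on the default CPU path; the sizes are placeholders:

import torch

from sglang.srt.disaggregation.utils import MetadataBuffers

# CPU-backed buffers (no custom_mem_pool), matching the default path in the diff.
buffers = MetadataBuffers(size=64, hidden_size=4096, dtype=torch.bfloat16)

ptrs, data_lens, item_lens = buffers.get_buf_infos()
# Six buffers are reported now: output ids, token logprob val/idx,
# top logprob val/idx, and the new per-request hidden-state row.
assert len(ptrs) == len(data_lens) == len(item_lens) == 6
assert buffers.output_hidden_states.shape == (64, 4096)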
sglang/srt/distributed/parallel_state.py:

@@ -523,17 +523,25 @@
         self,
         input_: torch.Tensor,
         dim: int = -1,
-        tensor_list: List[torch.Tensor] = None,
+        output_tensor_list: Optional[List[torch.Tensor]] = None,
     ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
-            return input_
+            if output_tensor_list is not None:
+                logger.warning(
+                    "Performing in-place all-gather with a group size of 1. "
+                    "This may be unnecessary; consider bypassing it for better efficiency."
+                )
+                output_tensor_list[0].copy_(input_)
+                return None
+            else:
+                return input_
 
-        if tensor_list is not None:
+        if output_tensor_list is not None:
             # TODO(ch-wan): support other backends
             return torch.distributed.all_gather(
-                tensor_list, input_, group=self.device_group
+                output_tensor_list, input_, group=self.device_group
             )
 
         assert (
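In parallel_state.py, GroupCoordinator.all_gather renames tensor_list to output_tensor_list and now honors a preallocated output even when the group has a single rank (copying the input instead of short-circuiting), so callers that gather in place behave the same at any world size. The multi-rank path still forwards to plain torch.distributed.all_gather; a single-process sketch of that preallocated-output pattern (the init_method address is arbitrary):

import torch
import torch.distributed as dist

# One-rank gloo group, just to exercise the preallocated-output pattern.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
)

x = torch.arange(4, dtype=torch.float32)
out = [torch.empty_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(out, x)  # out[0] now holds a copy of x, like the new world_size == 1 branch
assert torch.equal(out[0], x)

dist.destroy_process_group()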
sglang/srt/entrypoints/engine.py:

@@ -37,7 +37,6 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 import torch
 import uvloop
 
-from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
 from sglang.srt.entrypoints.EngineBase import EngineBase
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
@@ -58,11 +57,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import (
-    guess_chat_template_name_from_model_path,
-    load_chat_template_for_openai_api,
-)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -119,21 +115,22 @@ class Engine(EngineBase):
         atexit.register(self.shutdown)
 
         # Allocate ports for inter-process communications
-        port_args = PortArgs.init_new(server_args)
+        self.port_args = PortArgs.init_new(server_args)
         logger.info(f"{server_args=}")
 
         # Launch subprocesses
-        tokenizer_manager, scheduler_info = _launch_subprocesses(
+        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args,
-            port_args=port_args,
+            port_args=self.port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
+        self.template_manager = template_manager
         self.scheduler_info = scheduler_info
 
         context = zmq.Context(2)
         self.send_to_rpc = get_zmq_socket(
-            context, zmq.DEALER, port_args.rpc_ipc_name, True
+            context, zmq.DEALER, self.port_args.rpc_ipc_name, True
         )
 
     def generate(
@@ -175,7 +172,7 @@
         """
         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.info("data_parallel_rank not provided, using default dispatch")
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -245,6 +242,7 @@
         token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
         custom_logit_processor: Optional[Union[List[str], str]] = None,
+        return_hidden_states: bool = False,
         stream: bool = False,
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -258,7 +256,7 @@
 
         if self.server_args.enable_dp_attention:
             if data_parallel_rank is None:
-                logger.info("data_parallel_rank not provided, using default dispatch")
+                logger.debug("data_parallel_rank not provided, using default dispatch")
             elif data_parallel_rank < 0:
                 raise ValueError("data_parallel_rank must be non-negative")
             elif data_parallel_rank >= self.server_args.dp_size:
@@ -277,6 +275,7 @@
             top_logprobs_num=top_logprobs_num,
             token_ids_logprob=token_ids_logprob,
             lora_path=lora_path,
+            return_hidden_states=return_hidden_states,
             stream=stream,
             custom_logit_processor=custom_logit_processor,
             bootstrap_host=bootstrap_host,
@@ -479,17 +478,15 @@
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )
 
-    def release_memory_occupation(self):
-        """Release GPU occupation temporarily."""
-        obj = ReleaseMemoryOccupationReqInput()
+    def release_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.release_memory_occupation(obj, None)
         )
 
-    def resume_memory_occupation(self):
-        """Resume GPU occupation."""
-        obj = ResumeMemoryOccupationReqInput()
+    def resume_memory_occupation(self, tags: Optional[List[str]] = None):
+        obj = ResumeMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             self.tokenizer_manager.resume_memory_occupation(obj, None)
@@ -649,7 +646,7 @@ def _set_envs_and_config(server_args: ServerArgs):
 
 def _launch_subprocesses(
     server_args: ServerArgs, port_args: Optional[PortArgs] = None
-) -> Tuple[TokenizerManager, Dict]:
+) -> Tuple[TokenizerManager, TemplateManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -670,11 +667,9 @@
 
     scheduler_procs = []
     if server_args.dp_size == 1:
-        # Launch tensor parallel scheduler processes
         memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
-
         scheduler_pipe_readers = []
 
         nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
@@ -710,6 +705,7 @@
                     writer,
                 ),
             )
+
             with memory_saver_adapter.configure_subprocess():
                 proc.start()
             scheduler_procs.append(proc)
@@ -735,7 +731,7 @@
 
     if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
         # When using `Engine` as a Python API, we don't want to block here.
-        return None, None
+        return None, None, None
 
     launch_dummy_health_check_server(server_args.host, server_args.port)
 
@@ -744,7 +740,7 @@
         logger.error(
            f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
         )
-        return None, None
+        return None, None, None
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -758,15 +754,15 @@
 
     # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
-    if server_args.chat_template:
-        load_chat_template_for_openai_api(
-            tokenizer_manager, server_args.chat_template, server_args.model_path
-        )
-    else:
-        guess_chat_template_name_from_model_path(server_args.model_path)
 
-    if server_args.completion_template:
-        load_completion_template_for_openai_api(server_args.completion_template)
+    # Initialize templates
+    template_manager = TemplateManager()
+    template_manager.initialize_templates(
+        tokenizer_manager=tokenizer_manager,
+        model_path=server_args.model_path,
+        chat_template=server_args.chat_template,
+        completion_template=server_args.completion_template,
+    )
 
     # Wait for the model to finish loading
     scheduler_infos = []
@@ -790,4 +786,4 @@
     # Assume all schedulers have the same scheduler_info
     scheduler_info = scheduler_infos[0]
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
-    return tokenizer_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info
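Taken together, the engine.py hunks change the offline Engine surface in three user-visible ways: _launch_subprocesses also returns the new TemplateManager (chat/completion template handling moved out of the removed openai_api.adapter), generate() accepts return_hidden_states, and release/resume_memory_occupation take an optional tags list. A hedged usage sketch; the model path, prompt, and sampling parameters are placeholders:

import sglang as sgl

# Offline engine; templates are now resolved internally by TemplateManager.
llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

out = llm.generate(
    "The capital of France is",
    sampling_params={"max_new_tokens": 8},
    return_hidden_states=True,  # new flag threaded through to the request
)

# New optional tags argument; None keeps the old release-everything behavior.
llm.release_memory_occupation(tags=None)
llm.resume_memory_occupation(tags=None)

llm.shutdown()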