sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/transfer_engine.py

@@ -61,7 +61,8 @@ class MooncakeTransferEngine:
         self, session_id: str, buffer: int, peer_buffer_address: int, length: int
     ) -> int:
         """Synchronously transfer data to the specified address."""
-
+        # the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
+        # later: based on the cached queue pair to send data
         ret = self.engine.transfer_sync_write(
             session_id, buffer, peer_buffer_address, length
         )
sglang/srt/disaggregation/nixl/conn.py

@@ -35,29 +35,19 @@ logger = logging.getLogger(__name__)
 NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
 
 
-# From Mooncake backend.
 def group_concurrent_contiguous(
     src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
 ) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
-    src_groups = []
-    dst_groups = []
-    current_src = [src_indices[0]]
-    current_dst = [dst_indices[0]]
-
-    for i in range(1, len(src_indices)):
-        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
-        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
-        if src_contiguous and dst_contiguous:
-            current_src.append(src_indices[i])
-            current_dst.append(dst_indices[i])
-        else:
-            src_groups.append(current_src)
-            dst_groups.append(current_dst)
-            current_src = [src_indices[i]]
-            current_dst = [dst_indices[i]]
+    """Vectorised NumPy implementation."""
+    if src_indices.size == 0:
+        return [], []
+
+    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
+    src_groups = np.split(src_indices, brk)
+    dst_groups = np.split(dst_indices, brk)
 
-    src_groups.append(current_src)
-    dst_groups.append(current_dst)
+    src_groups = [g.tolist() for g in src_groups]
+    dst_groups = [g.tolist() for g in dst_groups]
 
     return src_groups, dst_groups
 
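The vectorised rewrite above replaces the element-by-element loop with np.diff/np.split: the index arrays are cut at every position where either sequence stops advancing by exactly 1. A minimal standalone sketch of the same trick, with made-up indices rather than anything from the package:

    import numpy as np

    src = np.array([3, 4, 5, 9, 10, 20], dtype=np.int64)
    dst = np.array([0, 1, 2, 3, 4, 5], dtype=np.int64)

    # Break after every position where either array jumps by something other than 1.
    brk = np.where((np.diff(src) != 1) | (np.diff(dst) != 1))[0] + 1
    print([g.tolist() for g in np.split(src, brk)])  # [[3, 4, 5], [9, 10], [20]]
    print([g.tolist() for g in np.split(dst, brk)])  # [[0, 1, 2], [3, 4], [5]]

Grouping contiguous runs this way lets the transfer backend move each run as one larger copy instead of many single-page transfers.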
sglang/srt/disaggregation/prefill.py

@@ -22,6 +22,7 @@ from __future__ import annotations
 import logging
 import threading
 from collections import deque
+from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional
 
 import torch
@@ -31,6 +32,7 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
@@ -38,8 +40,10 @@ from sglang.srt.disaggregation.utils import (
     kv_to_page_indices,
     kv_to_page_num,
     poll_and_all_reduce,
+    prepare_abort,
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -59,9 +63,9 @@ class PrefillBootstrapQueue:
     def __init__(
         self,
         token_to_kv_pool: KVCache,
+        draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
@@ -70,8 +74,9 @@
         scheduler: Scheduler,
     ):
         self.token_to_kv_pool = token_to_kv_pool
+        self.draft_token_to_kv_pool = draft_token_to_kv_pool
+
         self.is_mla_backend = is_mla_backend(token_to_kv_pool)
-        self.aux_dtype = aux_dtype
 
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -96,20 +101,24 @@
             self.token_to_kv_pool.get_contiguous_buf_infos()
         )
 
+        if self.draft_token_to_kv_pool is not None:
+            # We should also transfer draft model kv cache. The indices are
+            # always shared with a target model.
+            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
+                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
+            )
+            kv_data_ptrs += draft_kv_data_ptrs
+            kv_data_lens += draft_kv_data_lens
+            kv_item_lens += draft_kv_item_lens
+
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
 
         # Define req -> input ids buffer
-        kv_args.aux_data_ptrs = [
-            metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -135,6 +144,10 @@
             self._process_req(req)
         self.queue.append(req)
 
+    def extend(self, reqs: List[Req]) -> None:
+        for req in reqs:
+            self.add(req)
+
     def _process_req(self, req: Req) -> None:
         """
         Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
@@ -157,7 +170,18 @@
             if poll == KVPoll.Bootstrapping:
                 continue
             elif poll == KVPoll.Failed:
-                raise Exception("Bootstrap failed")
+                error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
+                try:
+                    req.disagg_kv_sender.failure_exception()
+                except Exception as e:
+                    error_message += f" with exception {e}"
+                logger.error(error_message)
+                prepare_abort(
+                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+                self.scheduler.stream_output([req], req.return_logprob)
+                indices_to_remove.add(i)
+                continue
 
             # KV.WaitingForInput
             num_kv_indices = len(req.origin_input_ids)
@@ -250,6 +274,16 @@ class SchedulerDisaggregationPrefillMixin:
                 result = self.run_batch(batch)
                 self.result_queue.append((batch.copy(), result))
 
+                if self.last_batch is None:
+                    # Create a dummy first batch to start the pipeline for overlap schedule.
+                    # It is now used for triggering the sampling_info_done event.
+                    tmp_batch = ScheduleBatch(
+                        reqs=None,
+                        forward_mode=ForwardMode.DUMMY_FIRST,
+                        next_batch_sampling_info=self.tp_worker.cur_sampling_info,
+                    )
+                    self.set_next_batch_sampling_info_done(tmp_batch)
+
             if self.last_batch:
                 tmp_batch, tmp_result = self.result_queue.popleft()
                 self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
@@ -273,10 +307,9 @@
         launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
-        Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
+        Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         Adapted from process_batch_result_prefill
         """
-
         (
             logits_output,
             next_token_ids,
@@ -289,27 +322,78 @@
             result.extend_logprob_start_len_per_req,
         )
 
+        logprob_pt = 0
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
+            logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
         else:
             next_token_ids = result.next_token_ids.tolist()
-
-        for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
+            if batch.return_logprob:
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs.tolist()
+                    )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = tuple(
+                        logits_output.input_token_logprobs.tolist()
+                    )
+        for i, (req, next_token_id) in enumerate(
+            zip(batch.reqs, next_token_ids, strict=True)
+        ):
             req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
-                self.send_kv_chunk(req, token_id=next_token_id)
                 self.disagg_prefill_inflight_queue.append(req)
+                if req.return_logprob:
+                    assert extend_logprob_start_len_per_req is not None
+                    assert extend_input_len_per_req is not None
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    num_input_logprobs = extend_input_len - extend_logprob_start_len
+                    self.add_logprob_return_values(
+                        i,
+                        req,
+                        logprob_pt,
+                        next_token_ids,
+                        num_input_logprobs,
+                        logits_output,
+                    )
+                    logprob_pt += num_input_logprobs
+                self.send_kv_chunk(req, last_chunk=True)
+
+                if req.grammar is not None:
+                    req.grammar.accept_token(next_token_id)
+                    req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
 
+                if req.return_logprob:
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    if extend_logprob_start_len < extend_input_len:
+                        # Update input logprobs.
+                        num_input_logprobs = extend_input_len - extend_logprob_start_len
+                        self.add_input_logprob_return_values(
+                            i,
+                            req,
+                            logits_output,
+                            logprob_pt,
+                            num_input_logprobs,
+                            last_prefill_chunk=False,
+                        )
+                        logprob_pt += num_input_logprobs
+
                 if self.enable_overlap:
-                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
+                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
+
+        # We need to remove the sync in the following function for overlap schedule.
+        self.set_next_batch_sampling_info_done(batch)
 
     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
@@ -335,7 +419,17 @@
                 # FIXME: clean up req's data in transfer engine
                 done_reqs.append(req)
             elif poll == KVPoll.Failed:
-                raise Exception("Transferring failed")
+                error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
+                try:
+                    req.disagg_kv_sender.failure_exception()
+                except Exception as e:
+                    error_message += f" with exception {e}"
+                logger.warning(error_message)
+                self.tree_cache.cache_finished_req(req)  # unlock the tree
+                prepare_abort(
+                    req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
+                )
+                done_reqs.append(req)
 
         for req in done_reqs:
             self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
@@ -343,7 +437,11 @@
             )
 
         # Stream requests which have finished transfer
-        self.stream_output(done_reqs, False, None)
+        self.stream_output(
+            done_reqs,
+            any(req.return_logprob for req in done_reqs),
+            None,
+        )
 
         self.disagg_prefill_inflight_queue = undone_reqs
 
@@ -369,7 +467,7 @@
     def send_kv_chunk(
         self: Scheduler,
         req: Req,
-        token_id: Optional[int] = None,
+        last_chunk: bool = False,
         end_idx: Optional[int] = None,
     ) -> None:
         """
@@ -377,44 +475,28 @@
         """
         page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
-        # the resolved length is not the same as fill_ids's length
         end_idx = (
             end_idx
             if end_idx is not None
             else min(len(req.fill_ids), len(req.origin_input_ids))
         )
-        last_chunk = token_id is not None
 
-        if (not last_chunk) and (
-            end_idx % page_size != 0
-        ):  # todo: remove the second condition
+        if not last_chunk:
             # if not the last chunk and the last page is partial, delay the last partial page to the next send
             end_idx = end_idx - end_idx % page_size
 
-        # Update next start_send_idx
-        req.start_send_idx = end_idx
-
         kv_indices = (
             self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
-        if last_chunk is True:
-            self.disagg_prefill_bootstrap_queue.store_prefill_results(
-                req.metadata_buffer_index, token_id
-            )
+        req.start_send_idx = end_idx
+        if last_chunk:
+            self.disagg_metadata_buffers.set_buf(req)
         page_indices = kv_to_page_indices(kv_indices, page_size)
-
-        page_start_idx = start_idx // page_size
-        page_end_idx = page_start_idx + len(page_indices)
-
         if len(page_indices) == 0:
             logger.info(
                 f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
            )
             return
-
-        req.disagg_kv_sender.send(
-            page_indices, slice(page_start_idx, page_end_idx), last_chunk
-        )
+        req.disagg_kv_sender.send(page_indices)
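The reworked send_kv_chunk drops the token_id and slice bookkeeping: callers now state explicitly whether this is the last chunk, and non-final chunks are truncated to a page boundary so a partially filled page is carried over into the next send. A toy illustration of that rounding, with an assumed page_size:

    page_size = 16

    def aligned_end(end_idx: int, last_chunk: bool) -> int:
        # Non-final chunks must stop on a page boundary; the last chunk flushes everything.
        return end_idx if last_chunk else end_idx - end_idx % page_size

    print(aligned_end(37, last_chunk=False))  # 32 -> tokens 32-36 wait for the next send
    print(aligned_end(37, last_chunk=True))   # 37 -> the final partial page is sent too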
sglang/srt/disaggregation/utils.py

@@ -1,10 +1,12 @@
 from __future__ import annotations
 
 import dataclasses
+import os
+import random
 import warnings
 from collections import deque
 from enum import Enum
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import requests
@@ -13,6 +15,14 @@ import torch.distributed as dist
 
 from sglang.srt.utils import get_ip
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
+FakeBootstrapHost = "2.2.2.2"
+
+# env var for testing failure, convert to float explicitly
+FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
+
 
 class DisaggregationMode(Enum):
     NULL = "null"
@@ -20,11 +30,17 @@ class DisaggregationMode(Enum):
     DECODE = "decode"
 
 
-FakeBootstrapHost = "2.2.2.2"
-
-
 def poll_and_all_reduce(pollers, gloo_group):
-    polls = [int(poller.poll()) for poller in pollers]
+    # at a certain prob, the poll is failed to simulate failure
+    if FAILURE_PROB > 0:
+        from sglang.srt.disaggregation.base import KVPoll
+
+        polls = [
+            int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
+            for poller in pollers
+        ]
+    else:
+        polls = [int(poller.poll()) for poller in pollers]
     tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
     dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
     return tensor_to_reduce.tolist()
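Because FAILURE_PROB is evaluated once at import time, the fault injection above can be driven entirely from the environment. A sketch of how a test might enable it (the probability value is illustrative):

    import os

    # Must be set before sglang.srt.disaggregation.utils is imported,
    # since FAILURE_PROB is read at module import time.
    os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.01"

    # From here on, roughly 1 in 100 polls reports KVPoll.Failed, exercising
    # the new bootstrap/transfer failure handling added in this release.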
@@ -168,3 +184,98 @@ def is_mla_backend(target_kv_pool) -> bool:
     from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 
     return isinstance(target_kv_pool, MLATokenToKVPool)
+
+
+def prepare_abort(req: Req, error_message: str, status_code=None):
+    from sglang.srt.managers.schedule_batch import FINISH_ABORT
+
+    # populate finish metadata and stream output
+    req.finished_reason = FINISH_ABORT(error_message, status_code)
+
+    if req.return_logprob:
+        req.input_token_logprobs_val = []
+        req.input_token_logprobs_idx = []
+        req.input_top_logprobs_val = []
+        req.input_top_logprobs_idx = []
+        req.input_token_ids_logprobs_val = []
+        req.input_token_ids_logprobs_idx = []
+
+
+class MetadataBuffers:
+    def __init__(self, size: int, max_top_logprobs_num: int = 128):
+        # TODO: abort top_logprobs_num > 128 in PD
+
+        # We transfer the metadata of first output token to decode
+        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
+        self.output_token_logprobs_val = torch.zeros(
+            (size, 16), dtype=torch.float32, device="cpu"
+        )
+        self.output_token_logprobs_idx = torch.zeros(
+            (size, 16), dtype=torch.int32, device="cpu"
+        )
+        self.output_top_logprobs_val = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
+        )
+        self.output_top_logprobs_idx = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
+        )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not none or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not none or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+
+            if req.output_top_logprobs_val:  # not none or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not none or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
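The 16-element rows in MetadataBuffers exist purely for padding: a single int32 output id is 4 bytes, well below the 64-byte RDMA minimum the comment mentions, so each per-request row is widened to 16 x 4 = 64 bytes. The per-item sizes that get_buf_infos reports can be checked directly:

    import torch

    row = torch.zeros((8, 16), dtype=torch.int32, device="cpu")
    print(row[0].nbytes)  # 64 -> one request's output_ids row meets the RDMA minimum

    top = torch.zeros((8, 128), dtype=torch.float32, device="cpu")
    print(top[0].nbytes)  # 512 -> per-request top-logprob values (128 float32 slots)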
sglang/srt/distributed/utils.py

@@ -127,14 +127,14 @@ class StatelessProcessGroup:
         key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
         self.store.set(key, pickle.dumps(obj))
         self.send_dst_counter[dst] += 1
-        self.entries.append((key, time.time()))
+        self.entries.append((key, time.perf_counter()))
 
     def expire_data(self):
         """Expire data that is older than `data_expiration_seconds` seconds."""
         while self.entries:
             # check the oldest entry
             key, timestamp = self.entries[0]
-            if time.time() - timestamp > self.data_expiration_seconds:
+            if time.perf_counter() - timestamp > self.data_expiration_seconds:
                 self.store.delete_key(key)
                 self.entries.popleft()
             else:
@@ -158,7 +158,7 @@
             key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
             self.store.set(key, pickle.dumps(obj))
             self.broadcast_send_counter += 1
-            self.entries.append((key, time.time()))
+            self.entries.append((key, time.perf_counter()))
             return obj
         else:
             key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
sglang/srt/entrypoints/EngineBase.py

@@ -27,6 +27,11 @@ class EngineBase(ABC):
         """Generate outputs based on given inputs."""
         pass
 
+    @abstractmethod
+    def flush_cache(self):
+        """Flush the cache of the engine."""
+        pass
+
     @abstractmethod
     def update_weights_from_tensor(
         self,
sglang/srt/entrypoints/engine.py

@@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
+    ImageDataItem,
     InitWeightsUpdateGroupReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
@@ -150,9 +151,9 @@ class Engine(EngineBase):
         # See also python/sglang/srt/utils.py:load_image for more details.
         image_data: Optional[
             Union[
-                List[List[Union[Image, str]]],
-                List[Union[Image, str]],
-                Union[Image, str],
+                List[List[ImageDataItem]],
+                List[ImageDataItem],
+                ImageDataItem,
             ]
         ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -221,9 +222,9 @@ class Engine(EngineBase):
         # See also python/sglang/srt/utils.py:load_image for more details.
         image_data: Optional[
             Union[
-                List[List[Union[Image, str]]],
-                List[Union[Image, str]],
-                Union[Image, str],
+                List[List[ImageDataItem]],
+                List[ImageDataItem],
+                ImageDataItem,
             ]
         ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -320,7 +321,26 @@
         loop.run_until_complete(self.tokenizer_manager.start_profile())
 
     def stop_profile(self):
-        self.tokenizer_manager.stop_profile()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.tokenizer_manager.stop_profile())
+
+    def start_expert_distribution_record(self):
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            self.tokenizer_manager.start_expert_distribution_record()
+        )
+
+    def stop_expert_distribution_record(self):
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            self.tokenizer_manager.stop_expert_distribution_record()
+        )
+
+    def dump_expert_distribution_record(self):
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            self.tokenizer_manager.dump_expert_distribution_record()
+        )
 
     def get_server_info(self):
         loop = asyncio.get_event_loop()
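The three new Engine wrappers make expert-distribution recording scriptable without going through the HTTP endpoints. A plausible offline workflow is sketched below; the model path and sampling parameters are illustrative, and recording may additionally require the matching expert-distribution server flags:

    import sglang as sgl

    engine = sgl.Engine(model_path="Qwen/Qwen1.5-MoE-A2.7B")  # any MoE checkpoint
    engine.start_expert_distribution_record()
    engine.generate("The capital of France is", {"max_new_tokens": 32})
    engine.stop_expert_distribution_record()
    engine.dump_expert_distribution_record()  # persists the recorded statistics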
@@ -486,7 +506,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.2.post1",
+            "0.1.4",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
sglang/srt/entrypoints/http_server.py

@@ -47,7 +47,7 @@ from sglang.srt.disaggregation.utils import (
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
-from sglang.srt.function_call_parser import FunctionCallParser
+from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -182,13 +182,14 @@ async def health_generate(request: Request) -> Response:
         async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
             break
 
-    tic = time.time()
+    tic = time.perf_counter()
     task = asyncio.create_task(gen())
-    while time.time() < tic + HEALTH_CHECK_TIMEOUT:
+    while time.perf_counter() < tic + HEALTH_CHECK_TIMEOUT:
         await asyncio.sleep(1)
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()
             _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
+            _global_state.tokenizer_manager.health_check_failed = False
             return Response(status_code=200)
 
     task.cancel()
@@ -202,6 +203,7 @@ async def health_generate(request: Request) -> Response:
         f"last_heartbeat time: {last_receive_time}"
     )
     _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
+    _global_state.tokenizer_manager.health_check_failed = True
     return Response(status_code=503)
 
 
@@ -353,7 +355,7 @@
 @app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
-    _global_state.tokenizer_manager.stop_profile()
+    await _global_state.tokenizer_manager.stop_profile()
     return Response(
         content="Stop profiling. This will take some time.\n",
         status_code=200,