sglang-0.4.6.post3-py3-none-any.whl → sglang-0.4.6.post5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ from __future__ import annotations
22
22
  import logging
23
23
  import threading
24
24
  from collections import deque
25
+ from http import HTTPStatus
25
26
  from typing import TYPE_CHECKING, List, Optional
26
27
 
27
28
  import torch
@@ -31,14 +32,18 @@ from sglang.srt.disaggregation.utils import (
31
32
  DisaggregationMode,
32
33
  FakeBootstrapHost,
33
34
  KVClassType,
35
+ MetadataBuffers,
34
36
  ReqToMetadataIdxAllocator,
35
37
  TransferBackend,
36
38
  get_kv_class,
39
+ is_mla_backend,
37
40
  kv_to_page_indices,
38
41
  kv_to_page_num,
39
42
  poll_and_all_reduce,
43
+ prepare_abort,
40
44
  )
41
45
  from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
46
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
42
47
 
43
48
  if TYPE_CHECKING:
44
49
  from torch.distributed import ProcessGroup
@@ -58,9 +63,9 @@ class PrefillBootstrapQueue:
58
63
  def __init__(
59
64
  self,
60
65
  token_to_kv_pool: KVCache,
66
+ draft_token_to_kv_pool: Optional[KVCache],
61
67
  req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
62
- metadata_buffers: List[torch.Tensor],
63
- aux_dtype: torch.dtype,
68
+ metadata_buffers: MetadataBuffers,
64
69
  tp_rank: int,
65
70
  tp_size: int,
66
71
  bootstrap_port: int,
@@ -69,7 +74,9 @@ class PrefillBootstrapQueue:
69
74
  scheduler: Scheduler,
70
75
  ):
71
76
  self.token_to_kv_pool = token_to_kv_pool
72
- self.aux_dtype = aux_dtype
77
+ self.draft_token_to_kv_pool = draft_token_to_kv_pool
78
+
79
+ self.is_mla_backend = is_mla_backend(token_to_kv_pool)
73
80
 
74
81
  self.metadata_buffers = metadata_buffers
75
82
  self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -94,25 +101,32 @@ class PrefillBootstrapQueue:
94
101
  self.token_to_kv_pool.get_contiguous_buf_infos()
95
102
  )
96
103
 
104
+ if self.draft_token_to_kv_pool is not None:
105
+ # We should also transfer draft model kv cache. The indices are
106
+ # always shared with a target model.
107
+ draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
108
+ self.draft_token_to_kv_pool.get_contiguous_buf_infos()
109
+ )
110
+ kv_data_ptrs += draft_kv_data_ptrs
111
+ kv_data_lens += draft_kv_data_lens
112
+ kv_item_lens += draft_kv_item_lens
113
+
97
114
  kv_args.kv_data_ptrs = kv_data_ptrs
98
115
  kv_args.kv_data_lens = kv_data_lens
99
116
  kv_args.kv_item_lens = kv_item_lens
100
117
 
101
118
  # Define req -> input ids buffer
102
- kv_args.aux_data_ptrs = [
103
- metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
104
- ]
105
- kv_args.aux_data_lens = [
106
- metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
107
- ]
108
- kv_args.aux_item_lens = [
109
- metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
110
- ]
119
+ kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
120
+ self.metadata_buffers.get_buf_infos()
121
+ )
111
122
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
112
123
  kv_args.gpu_id = self.scheduler.gpu_id
113
124
  kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
114
125
  kv_manager = kv_manager_class(
115
- kv_args, DisaggregationMode.PREFILL, self.scheduler.server_args
126
+ kv_args,
127
+ DisaggregationMode.PREFILL,
128
+ self.scheduler.server_args,
129
+ self.is_mla_backend,
116
130
  )
117
131
  return kv_manager
118
132
 
@@ -130,6 +144,10 @@ class PrefillBootstrapQueue:
130
144
  self._process_req(req)
131
145
  self.queue.append(req)
132
146
 
147
+ def extend(self, reqs: List[Req]) -> None:
148
+ for req in reqs:
149
+ self.add(req)
150
+
133
151
  def _process_req(self, req: Req) -> None:
134
152
  """
135
153
  Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
@@ -152,7 +170,18 @@ class PrefillBootstrapQueue:
152
170
  if poll == KVPoll.Bootstrapping:
153
171
  continue
154
172
  elif poll == KVPoll.Failed:
155
- raise Exception("Bootstrap failed")
173
+ error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
174
+ try:
175
+ req.disagg_kv_sender.failure_exception()
176
+ except Exception as e:
177
+ error_message += f" with exception {e}"
178
+ logger.error(error_message)
179
+ prepare_abort(
180
+ req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
181
+ )
182
+ self.scheduler.stream_output([req], req.return_logprob)
183
+ indices_to_remove.add(i)
184
+ continue
156
185
 
157
186
  # KV.WaitingForInput
158
187
  num_kv_indices = len(req.origin_input_ids)
@@ -245,6 +274,16 @@ class SchedulerDisaggregationPrefillMixin:
245
274
  result = self.run_batch(batch)
246
275
  self.result_queue.append((batch.copy(), result))
247
276
 
277
+ if self.last_batch is None:
278
+ # Create a dummy first batch to start the pipeline for overlap schedule.
279
+ # It is now used for triggering the sampling_info_done event.
280
+ tmp_batch = ScheduleBatch(
281
+ reqs=None,
282
+ forward_mode=ForwardMode.DUMMY_FIRST,
283
+ next_batch_sampling_info=self.tp_worker.cur_sampling_info,
284
+ )
285
+ self.set_next_batch_sampling_info_done(tmp_batch)
286
+
248
287
  if self.last_batch:
249
288
  tmp_batch, tmp_result = self.result_queue.popleft()
250
289
  self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
@@ -268,45 +307,93 @@ class SchedulerDisaggregationPrefillMixin:
268
307
  launch_done: Optional[threading.Event] = None,
269
308
  ) -> None:
270
309
  """
271
- Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
310
+ Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
272
311
  Adapted from process_batch_result_prefill
273
312
  """
274
-
275
313
  (
276
314
  logits_output,
277
315
  next_token_ids,
278
316
  extend_input_len_per_req,
279
317
  extend_logprob_start_len_per_req,
280
- bid,
281
318
  ) = (
282
319
  result.logits_output,
283
320
  result.next_token_ids,
284
321
  result.extend_input_len_per_req,
285
322
  result.extend_logprob_start_len_per_req,
286
- result.bid,
287
323
  )
288
324
 
325
+ logprob_pt = 0
289
326
  # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
290
327
  if self.enable_overlap:
291
328
  # wait
292
- _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
329
+ logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
330
+ launch_done
331
+ )
293
332
  else:
294
333
  next_token_ids = result.next_token_ids.tolist()
295
-
296
- for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
334
+ if batch.return_logprob:
335
+ if logits_output.next_token_logprobs is not None:
336
+ logits_output.next_token_logprobs = (
337
+ logits_output.next_token_logprobs.tolist()
338
+ )
339
+ if logits_output.input_token_logprobs is not None:
340
+ logits_output.input_token_logprobs = tuple(
341
+ logits_output.input_token_logprobs.tolist()
342
+ )
343
+ for i, (req, next_token_id) in enumerate(
344
+ zip(batch.reqs, next_token_ids, strict=True)
345
+ ):
297
346
  req: Req
298
347
  if req.is_chunked <= 0:
299
348
  # There is no output_ids for prefill
300
349
  req.output_ids.append(next_token_id)
301
350
  self.tree_cache.cache_unfinished_req(req) # update the tree and lock
302
- self.send_kv_chunk(req, token_id=next_token_id)
303
351
  self.disagg_prefill_inflight_queue.append(req)
352
+ if req.return_logprob:
353
+ assert extend_logprob_start_len_per_req is not None
354
+ assert extend_input_len_per_req is not None
355
+ extend_logprob_start_len = extend_logprob_start_len_per_req[i]
356
+ extend_input_len = extend_input_len_per_req[i]
357
+ num_input_logprobs = extend_input_len - extend_logprob_start_len
358
+ self.add_logprob_return_values(
359
+ i,
360
+ req,
361
+ logprob_pt,
362
+ next_token_ids,
363
+ num_input_logprobs,
364
+ logits_output,
365
+ )
366
+ logprob_pt += num_input_logprobs
367
+ self.send_kv_chunk(req, last_chunk=True)
368
+
369
+ if req.grammar is not None:
370
+ req.grammar.accept_token(next_token_id)
371
+ req.grammar.finished = req.finished()
304
372
  else:
305
373
  # being chunked reqs' prefill is not finished
306
374
  req.is_chunked -= 1
307
375
 
376
+ if req.return_logprob:
377
+ extend_logprob_start_len = extend_logprob_start_len_per_req[i]
378
+ extend_input_len = extend_input_len_per_req[i]
379
+ if extend_logprob_start_len < extend_input_len:
380
+ # Update input logprobs.
381
+ num_input_logprobs = extend_input_len - extend_logprob_start_len
382
+ self.add_input_logprob_return_values(
383
+ i,
384
+ req,
385
+ logits_output,
386
+ logprob_pt,
387
+ num_input_logprobs,
388
+ last_prefill_chunk=False,
389
+ )
390
+ logprob_pt += num_input_logprobs
391
+
308
392
  if self.enable_overlap:
309
- self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
393
+ self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
394
+
395
+ # We need to remove the sync in the following function for overlap schedule.
396
+ self.set_next_batch_sampling_info_done(batch)
310
397
 
311
398
  def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
312
399
  """
@@ -332,7 +419,17 @@ class SchedulerDisaggregationPrefillMixin:
332
419
  # FIXME: clean up req's data in transfer engine
333
420
  done_reqs.append(req)
334
421
  elif poll == KVPoll.Failed:
335
- raise Exception("Transferring failed")
422
+ error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
423
+ try:
424
+ req.disagg_kv_sender.failure_exception()
425
+ except Exception as e:
426
+ error_message += f" with exception {e}"
427
+ logger.warning(error_message)
428
+ self.tree_cache.cache_finished_req(req) # unlock the tree
429
+ prepare_abort(
430
+ req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
431
+ )
432
+ done_reqs.append(req)
336
433
 
337
434
  for req in done_reqs:
338
435
  self.disagg_prefill_bootstrap_queue.req_to_metadata_buffer_idx_allocator.free(
@@ -340,7 +437,11 @@ class SchedulerDisaggregationPrefillMixin:
340
437
  )
341
438
 
342
439
  # Stream requests which have finished transfer
343
- self.stream_output(done_reqs, False, None)
440
+ self.stream_output(
441
+ done_reqs,
442
+ any(req.return_logprob for req in done_reqs),
443
+ None,
444
+ )
344
445
 
345
446
  self.disagg_prefill_inflight_queue = undone_reqs
346
447
 
@@ -366,7 +467,7 @@ class SchedulerDisaggregationPrefillMixin:
366
467
  def send_kv_chunk(
367
468
  self: Scheduler,
368
469
  req: Req,
369
- token_id: Optional[int] = None,
470
+ last_chunk: bool = False,
370
471
  end_idx: Optional[int] = None,
371
472
  ) -> None:
372
473
  """
@@ -374,44 +475,28 @@ class SchedulerDisaggregationPrefillMixin:
374
475
  """
375
476
  page_size = self.token_to_kv_pool_allocator.page_size
376
477
  start_idx = req.start_send_idx
377
- # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
378
- # the resolved length is not the same as fill_ids's length
379
478
  end_idx = (
380
479
  end_idx
381
480
  if end_idx is not None
382
481
  else min(len(req.fill_ids), len(req.origin_input_ids))
383
482
  )
384
- last_chunk = token_id is not None
385
483
 
386
- if (not last_chunk) and (
387
- end_idx % page_size != 0
388
- ): # todo: remove the second condition
484
+ if not last_chunk:
389
485
  # if not the last chunk and the last page is partial, delay the last partial page to the next send
390
486
  end_idx = end_idx - end_idx % page_size
391
487
 
392
- # Update next start_send_idx
393
- req.start_send_idx = end_idx
394
-
395
488
  kv_indices = (
396
489
  self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
397
490
  .cpu()
398
491
  .numpy()
399
492
  )
400
- if last_chunk is True:
401
- self.disagg_prefill_bootstrap_queue.store_prefill_results(
402
- req.metadata_buffer_index, token_id
403
- )
493
+ req.start_send_idx = end_idx
494
+ if last_chunk:
495
+ self.disagg_metadata_buffers.set_buf(req)
404
496
  page_indices = kv_to_page_indices(kv_indices, page_size)
405
-
406
- page_start_idx = start_idx // page_size
407
- page_end_idx = page_start_idx + len(page_indices)
408
-
409
497
  if len(page_indices) == 0:
410
498
  logger.info(
411
499
  f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
412
500
  )
413
501
  return
414
-
415
- req.disagg_kv_sender.send(
416
- page_indices, slice(page_start_idx, page_end_idx), last_chunk
417
- )
502
+ req.disagg_kv_sender.send(page_indices)
@@ -1,10 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import dataclasses
4
+ import os
5
+ import random
4
6
  import warnings
5
7
  from collections import deque
6
8
  from enum import Enum
7
- from typing import List, Optional
9
+ from typing import TYPE_CHECKING, List, Optional
8
10
 
9
11
  import numpy as np
10
12
  import requests
@@ -13,6 +15,14 @@ import torch.distributed as dist
13
15
 
14
16
  from sglang.srt.utils import get_ip
15
17
 
18
+ if TYPE_CHECKING:
19
+ from sglang.srt.managers.schedule_batch import Req
20
+
21
+ FakeBootstrapHost = "2.2.2.2"
22
+
23
+ # env var for testing failure, convert to float explicitly
24
+ FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
25
+
16
26
 
17
27
  class DisaggregationMode(Enum):
18
28
  NULL = "null"
@@ -20,11 +30,17 @@ class DisaggregationMode(Enum):
20
30
  DECODE = "decode"
21
31
 
22
32
 
23
- FakeBootstrapHost = "2.2.2.2"
24
-
25
-
26
33
  def poll_and_all_reduce(pollers, gloo_group):
27
- polls = [int(poller.poll()) for poller in pollers]
34
+ # at a certain prob, the poll is failed to simulate failure
35
+ if FAILURE_PROB > 0:
36
+ from sglang.srt.disaggregation.base import KVPoll
37
+
38
+ polls = [
39
+ int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
40
+ for poller in pollers
41
+ ]
42
+ else:
43
+ polls = [int(poller.poll()) for poller in pollers]
28
44
  tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
29
45
  dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
30
46
  return tensor_to_reduce.tolist()
@@ -112,7 +128,7 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
112
128
 
113
129
 
114
130
  def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
115
- # 1. The page is guaruanteed to be full except the last page.
131
+ # 1. The page is guaranteed to be full except the last page.
116
132
  # 2. page index = kv_index // page_size
117
133
  # The return vector is kv_indices[::page_size] // page_size
118
134
  if page_size == 1: # shortcut
@@ -162,3 +178,104 @@ def register_disaggregation_server(
162
178
  warnings.warn(
163
179
  f"Failed to register disaggregation server: {res.status_code} {res.text}"
164
180
  )
181
+
182
+
183
+ def is_mla_backend(target_kv_pool) -> bool:
184
+ from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
185
+
186
+ return isinstance(target_kv_pool, MLATokenToKVPool)
187
+
188
+
189
+ def prepare_abort(req: Req, error_message: str, status_code=None):
190
+ from sglang.srt.managers.schedule_batch import FINISH_ABORT
191
+
192
+ # populate finish metadata and stream output
193
+ req.finished_reason = FINISH_ABORT(error_message, status_code)
194
+
195
+ if req.return_logprob:
196
+ req.input_token_logprobs_val = []
197
+ req.input_token_logprobs_idx = []
198
+ req.input_top_logprobs_val = []
199
+ req.input_top_logprobs_idx = []
200
+ req.input_token_ids_logprobs_val = []
201
+ req.input_token_ids_logprobs_idx = []
202
+
203
+
204
+ class MetadataBuffers:
205
+ def __init__(self, size: int, max_top_logprobs_num: int = 128):
206
+ # TODO: abort top_logprobs_num > 128 in PD
207
+
208
+ # We transfer the metadata of first output token to decode
209
+ # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
210
+ self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
211
+ self.output_token_logprobs_val = torch.zeros(
212
+ (size, 16), dtype=torch.float32, device="cpu"
213
+ )
214
+ self.output_token_logprobs_idx = torch.zeros(
215
+ (size, 16), dtype=torch.int32, device="cpu"
216
+ )
217
+ self.output_top_logprobs_val = torch.zeros(
218
+ (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
219
+ )
220
+ self.output_top_logprobs_idx = torch.zeros(
221
+ (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
222
+ )
223
+
224
+ def get_buf_infos(self):
225
+ ptrs = [
226
+ self.output_ids.data_ptr(),
227
+ self.output_token_logprobs_val.data_ptr(),
228
+ self.output_token_logprobs_idx.data_ptr(),
229
+ self.output_top_logprobs_val.data_ptr(),
230
+ self.output_top_logprobs_idx.data_ptr(),
231
+ ]
232
+ data_lens = [
233
+ self.output_ids.nbytes,
234
+ self.output_token_logprobs_val.nbytes,
235
+ self.output_token_logprobs_idx.nbytes,
236
+ self.output_top_logprobs_val.nbytes,
237
+ self.output_top_logprobs_idx.nbytes,
238
+ ]
239
+ item_lens = [
240
+ self.output_ids[0].nbytes,
241
+ self.output_token_logprobs_val[0].nbytes,
242
+ self.output_token_logprobs_idx[0].nbytes,
243
+ self.output_top_logprobs_val[0].nbytes,
244
+ self.output_top_logprobs_idx[0].nbytes,
245
+ ]
246
+ return ptrs, data_lens, item_lens
247
+
248
+ def get_buf(self, idx: int):
249
+ return (
250
+ self.output_ids[idx],
251
+ self.output_token_logprobs_val[idx],
252
+ self.output_token_logprobs_idx[idx],
253
+ self.output_top_logprobs_val[idx],
254
+ self.output_top_logprobs_idx[idx],
255
+ )
256
+
257
+ def set_buf(self, req: Req):
258
+
259
+ self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
260
+ if req.return_logprob:
261
+ if req.output_token_logprobs_val: # not none or empty list
262
+ self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
263
+ req.output_token_logprobs_val[0]
264
+ )
265
+ if req.output_token_logprobs_idx: # not none or empty list
266
+ self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
267
+ req.output_token_logprobs_idx[0]
268
+ )
269
+
270
+ if req.output_top_logprobs_val: # not none or empty list
271
+ self.output_top_logprobs_val[req.metadata_buffer_index][
272
+ : len(req.output_top_logprobs_val[0])
273
+ ] = torch.tensor(
274
+ req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
275
+ )
276
+ if req.output_top_logprobs_idx: # not none or empty list
277
+ self.output_top_logprobs_idx[req.metadata_buffer_index][
278
+ : len(req.output_top_logprobs_idx[0])
279
+ ] = torch.tensor(
280
+ req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
281
+ )
@@ -127,14 +127,14 @@ class StatelessProcessGroup:
127
127
  key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
128
128
  self.store.set(key, pickle.dumps(obj))
129
129
  self.send_dst_counter[dst] += 1
130
- self.entries.append((key, time.time()))
130
+ self.entries.append((key, time.perf_counter()))
131
131
 
132
132
  def expire_data(self):
133
133
  """Expire data that is older than `data_expiration_seconds` seconds."""
134
134
  while self.entries:
135
135
  # check the oldest entry
136
136
  key, timestamp = self.entries[0]
137
- if time.time() - timestamp > self.data_expiration_seconds:
137
+ if time.perf_counter() - timestamp > self.data_expiration_seconds:
138
138
  self.store.delete_key(key)
139
139
  self.entries.popleft()
140
140
  else:
@@ -158,7 +158,7 @@ class StatelessProcessGroup:
158
158
  key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
159
159
  self.store.set(key, pickle.dumps(obj))
160
160
  self.broadcast_send_counter += 1
161
- self.entries.append((key, time.time()))
161
+ self.entries.append((key, time.perf_counter()))
162
162
  return obj
163
163
  else:
164
164
  key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
@@ -27,6 +27,11 @@ class EngineBase(ABC):
27
27
  """Generate outputs based on given inputs."""
28
28
  pass
29
29
 
30
+ @abstractmethod
31
+ def flush_cache(self):
32
+ """Flush the cache of the engine."""
33
+ pass
34
+
30
35
  @abstractmethod
31
36
  def update_weights_from_tensor(
32
37
  self,
@@ -47,6 +47,7 @@ from sglang.srt.managers.io_struct import (
47
47
  EmbeddingReqInput,
48
48
  GenerateReqInput,
49
49
  GetWeightsByNameReqInput,
50
+ ImageDataItem,
50
51
  InitWeightsUpdateGroupReqInput,
51
52
  ReleaseMemoryOccupationReqInput,
52
53
  ResumeMemoryOccupationReqInput,
@@ -150,9 +151,9 @@ class Engine(EngineBase):
150
151
  # See also python/sglang/srt/utils.py:load_image for more details.
151
152
  image_data: Optional[
152
153
  Union[
153
- List[List[Union[Image, str]]],
154
- List[Union[Image, str]],
155
- Union[Image, str],
154
+ List[List[ImageDataItem]],
155
+ List[ImageDataItem],
156
+ ImageDataItem,
156
157
  ]
157
158
  ] = None,
158
159
  return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -221,9 +222,9 @@ class Engine(EngineBase):
221
222
  # See also python/sglang/srt/utils.py:load_image for more details.
222
223
  image_data: Optional[
223
224
  Union[
224
- List[List[Union[Image, str]]],
225
- List[Union[Image, str]],
226
- Union[Image, str],
225
+ List[List[ImageDataItem]],
226
+ List[ImageDataItem],
227
+ ImageDataItem,
227
228
  ]
228
229
  ] = None,
229
230
  return_logprob: Optional[Union[List[bool], bool]] = False,
@@ -285,6 +286,21 @@ class Engine(EngineBase):
285
286
  ret = loop.run_until_complete(generator.__anext__())
286
287
  return ret
287
288
 
289
+ async def async_encode(
290
+ self,
291
+ prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
292
+ image_data: Optional[Union[List[str], str]] = None,
293
+ ) -> Dict:
294
+ """
295
+ Asynchronous version of encode method.
296
+
297
+ The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
298
+ Please refer to `EmbeddingReqInput` for the documentation.
299
+ """
300
+ obj = EmbeddingReqInput(text=prompt, image_data=image_data)
301
+ generator = self.tokenizer_manager.generate_request(obj, None)
302
+ return await generator.__anext__()
303
+
288
304
  def shutdown(self):
289
305
  """Shutdown the engine"""
290
306
  kill_process_tree(os.getpid(), include_parent=False)
@@ -305,7 +321,26 @@ class Engine(EngineBase):
305
321
  loop.run_until_complete(self.tokenizer_manager.start_profile())
306
322
 
307
323
  def stop_profile(self):
308
- self.tokenizer_manager.stop_profile()
324
+ loop = asyncio.get_event_loop()
325
+ loop.run_until_complete(self.tokenizer_manager.stop_profile())
326
+
327
+ def start_expert_distribution_record(self):
328
+ loop = asyncio.get_event_loop()
329
+ loop.run_until_complete(
330
+ self.tokenizer_manager.start_expert_distribution_record()
331
+ )
332
+
333
+ def stop_expert_distribution_record(self):
334
+ loop = asyncio.get_event_loop()
335
+ loop.run_until_complete(
336
+ self.tokenizer_manager.stop_expert_distribution_record()
337
+ )
338
+
339
+ def dump_expert_distribution_record(self):
340
+ loop = asyncio.get_event_loop()
341
+ loop.run_until_complete(
342
+ self.tokenizer_manager.dump_expert_distribution_record()
343
+ )
309
344
 
310
345
  def get_server_info(self):
311
346
  loop = asyncio.get_event_loop()
@@ -315,7 +350,7 @@ class Engine(EngineBase):
315
350
  return {
316
351
  **dataclasses.asdict(self.tokenizer_manager.server_args),
317
352
  **self.scheduler_info,
318
- **internal_states,
353
+ "internal_states": internal_states,
319
354
  "version": __version__,
320
355
  }
321
356
 
@@ -471,7 +506,7 @@ def _set_envs_and_config(server_args: ServerArgs):
471
506
  if _is_cuda:
472
507
  assert_pkg_version(
473
508
  "sgl-kernel",
474
- "0.1.1",
509
+ "0.1.4",
475
510
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
476
511
  )
477
512