sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
@@ -10,9 +10,9 @@ Life cycle of a request in the prefill server
 2. Waiting Queue
     a. Use PrefillAdder to pop requests
     b. Run forward
-    c. Add the request to Infight Queue
+    c. Add the request to Inflight Queue
 
-3. Infight Queue
+3. Inflight Queue
     a. Poll (non-blocking) the sender of the request
     b. Once the transfer has finished, return the request
 """
@@ -24,9 +24,15 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
-from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
+from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    KVClassType,
     ReqToMetadataIdxAllocator,
+    TransferBackend,
+    get_kv_class,
+    kv_to_page_indices,
+    kv_to_page_num,
     poll_and_all_reduce,
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
@@ -37,6 +43,7 @@ if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
     from sglang.srt.mem_cache.memory_pool import KVCache
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -55,6 +62,8 @@ class PrefillBootstrapQueue:
         tp_size: int,
         bootstrap_port: int,
         gloo_group: ProcessGroup,
+        transfer_backend: TransferBackend,
+        scheduler: Scheduler,
     ):
         self.token_to_kv_pool = token_to_kv_pool
         self.aux_dtype = aux_dtype
@@ -63,17 +72,19 @@ class PrefillBootstrapQueue:
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.transfer_backend = transfer_backend
+        self.scheduler = scheduler
         self.kv_manager = self._init_kv_manager()
         self.queue: List[Req] = []
         self.gloo_group = gloo_group
         self.bootstrap_port = bootstrap_port
 
-    def allocate_token_id(self, idx: int, token_id: int):
+    def store_prefill_results(self, idx: int, token_id: int):
         assert token_id >= 0, f"token_id: {token_id} is negative"
         output_id_buffer = self.metadata_buffers[0]
         output_id_buffer[idx] = token_id
 
-    def _init_kv_manager(self) -> KVManager:
+    def _init_kv_manager(self) -> BaseKVManager:
         kv_args = KVArgs()
         kv_args.engine_rank = self.tp_rank
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
@@ -94,12 +105,17 @@ class PrefillBootstrapQueue:
         kv_args.aux_item_lens = [
             metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
         ]
-        kv_args.ib_device = "mock-ib-device"
-        kv_manager = KVManager(kv_args)
+        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
+        kv_args.gpu_id = self.scheduler.gpu_id
+        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
+        kv_manager = kv_manager_class(
+            kv_args, DisaggregationMode.PREFILL, self.scheduler.server_args
+        )
         return kv_manager
 
     def add(self, req: Req) -> None:
-        req.disagg_kv_sender = KVSender(
+        kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
+        req.disagg_kv_sender = kv_sender_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
@@ -131,7 +147,7 @@ class PrefillBootstrapQueue:
             elif poll == KVPoll.Failed:
                 raise Exception("Bootstrap failed")
 
-            # KV.WaitingForInput - init here
+            # KV.WaitingForInput
             num_kv_indices = len(req.origin_input_ids)
             if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
                 break
@@ -140,7 +156,8 @@ class PrefillBootstrapQueue:
                 self.req_to_metadata_buffer_idx_allocator.alloc()
             )
             assert req.metadata_buffer_index is not None
-            req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
+            num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
+            req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
 
             bootstrapped_reqs.append(req)
             indices_to_remove.add(i)
@@ -157,11 +174,41 @@ class SchedulerDisaggregationPrefillMixin:
     Mixin for Scheduler to handle disaggregation prefill
     """
 
+    @torch.no_grad()
+    def event_loop_normal_disagg_prefill(self):
+        """A normal scheduler loop for prefill worker in disaggregation mode."""
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            self.waiting_queue.extend(
+                self.disagg_prefill_pending_queue.pop_bootstrapped()
+            )
+            self.process_prefill_chunk()
+            batch = self.get_new_batch_prefill()
+            self.cur_batch = batch
+
+            if batch:
+                result = self.run_batch(batch)
+                self.process_batch_result_disagg_prefill(batch, result)
+
+            if len(self.disagg_prefill_inflight_queue) > 0:
+                self.process_disagg_prefill_inflight_queue()
+
+            if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
+            # Otherwise, it hangs under high concurrency
+            self.running_batch.batch_is_full = False
+
     def process_batch_result_disagg_prefill(
         self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
     ) -> None:
         """
-        Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
+        Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
         Adapted from process_batch_result_prefill
         """
 
@@ -174,7 +221,7 @@ class SchedulerDisaggregationPrefillMixin:
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                 self.send_kv_chunk(req, token_id=next_token_id)
-                self.disagg_prefill_infight_queue.append(req)
+                self.disagg_prefill_inflight_queue.append(req)
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
@@ -186,35 +233,41 @@ class SchedulerDisaggregationPrefillMixin:
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()
 
-    def process_disagg_prefill_infight_queue(self: Scheduler) -> None:
+    def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
         Poll the requests in the middle of transfer. If done, return the request.
         """
-        assert len(self.disagg_prefill_infight_queue) > 0
+        assert len(self.disagg_prefill_inflight_queue) > 0
 
         done_reqs = []
 
         polls = poll_and_all_reduce(
-            [req.disagg_kv_sender for req in self.disagg_prefill_infight_queue],
-            self.tp_worker.get_tp_cpu_group(),
+            [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
+            self.attn_tp_cpu_group,
         )
 
         undone_reqs: List[Req] = []
-        # Check .poll() for the reqs in disagg_prefill_infight_queue. If Success, respond to the client and remove it from the queue
-        for req, poll in zip(self.disagg_prefill_infight_queue, polls):
+        # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
+        for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
             if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
                 undone_reqs.append(req)
             elif poll == KVPoll.Success:  # transfer done
                 self.tree_cache.cache_finished_req(req)  # unlock the tree
                 req.finished_reason = FINISH_LENGTH(length=0)
+                # FIXME: clean up req's data in transfer engine
                 done_reqs.append(req)
             elif poll == KVPoll.Failed:
                 raise Exception("Transferring failed")
 
+        for req in done_reqs:
+            self.disagg_prefill_pending_queue.req_to_metadata_buffer_idx_allocator.free(
+                req.metadata_buffer_index
+            )
+
         # Stream requests which have finished transfer
         self.stream_output(done_reqs, False, None)
 
-        self.disagg_prefill_infight_queue = undone_reqs
+        self.disagg_prefill_inflight_queue = undone_reqs
 
     def process_prefill_chunk(self: Scheduler) -> None:
         if self.last_batch and self.last_batch.forward_mode.is_extend():
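For reference, poll_and_all_reduce (imported above from sglang/srt/disaggregation/utils.py) is what keeps the TP ranks in lockstep here. A conceptual sketch only, under the assumption that it min-reduces the per-sender poll states across the CPU group so every rank takes the same branch; the function below is illustrative, not the released implementation:

    import torch
    import torch.distributed as dist

    def poll_and_all_reduce_sketch(senders, cpu_group):
        # Each rank polls its senders locally...
        polls = torch.tensor([int(s.poll()) for s in senders], dtype=torch.int64)
        # ...then the group takes the elementwise MIN, so all ranks agree on the
        # most conservative per-request state before acting on it.
        dist.all_reduce(polls, op=dist.ReduceOp.MIN, group=cpu_group)
        return polls.tolist()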
@@ -236,14 +289,21 @@ class SchedulerDisaggregationPrefillMixin:
        """
        start_idx = req.start_send_idx
        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
+
+        # Update next start_send_idx
+        req.start_send_idx = end_idx
+
        kv_indices = (
            self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
            .cpu()
            .numpy()
        )
-        req.start_send_idx = end_idx
        if token_id is not None:
-            self.disagg_prefill_pending_queue.allocate_token_id(
+            self.disagg_prefill_pending_queue.store_prefill_results(
                req.metadata_buffer_index, token_id
            )
-        req.disagg_kv_sender.send(kv_indices)
+        is_last = token_id is not None
+        page_indices = kv_to_page_indices(
+            kv_indices, self.token_to_kv_pool_allocator.page_size
+        )
+        req.disagg_kv_sender.send(page_indices, slice(start_idx, end_idx), is_last)
@@ -4,6 +4,7 @@ from collections import deque
 from enum import Enum
 from typing import List
 
+import numpy as np
 import torch
 import torch.distributed as dist
 
@@ -42,3 +43,48 @@ class ReqToMetadataIdxAllocator:
 
     def free(self, free_index: int):
         self.free_slots.append(free_index)
+
+
+class TransferBackend(Enum):
+    MOONCAKE = "mooncake"
+    FAKE = "fake"
+
+
+class KVClassType(Enum):
+    MANAGER = "manager"
+    SENDER = "sender"
+    RECEIVER = "receiver"
+    BOOTSTRAP_SERVER = "bootstrap_server"
+
+
+def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
+    if transfer_backend == TransferBackend.MOONCAKE:
+        from sglang.srt.disaggregation.mooncake import (
+            MooncakeKVBootstrapServer,
+            MooncakeKVManager,
+            MooncakeKVReceiver,
+            MooncakeKVSender,
+        )
+
+        class_mapping = {
+            KVClassType.MANAGER: MooncakeKVManager,
+            KVClassType.SENDER: MooncakeKVSender,
+            KVClassType.RECEIVER: MooncakeKVReceiver,
+            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
+        }
+        return class_mapping.get(class_type)
+    raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
+
+
+def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
+    # 1. The page is guaranteed to be full except the last page.
+    # 2. page index = kv_index // page_size
+    # The return vector is kv_indices[::page_size] // page_size
+    if page_size == 1:  # shortcut
+        return kv_indices
+    return kv_indices[::page_size] // page_size
+
+
+def kv_to_page_num(num_kv_indices: int, page_size: int):
+    # ceil(num_kv_indices / page_size)
+    return (num_kv_indices + page_size - 1) // page_size
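The two paging helpers are plain index arithmetic. A toy illustration, assuming the allocator hands out page-aligned runs as the comment above implies (the values are made up):

    import numpy as np

    # 10 contiguous KV slots starting at a page boundary, page_size = 4.
    kv_indices = np.arange(32, 42)
    kv_to_page_indices(kv_indices, 4)  # -> array([ 8,  9, 10]): two full pages + one partial
    kv_to_page_num(10, 4)              # -> 3, i.e. ceil(10 / 4)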
@@ -0,0 +1,53 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+import torch
+
+
+class EngineBase(ABC):
+    """
+    Abstract base class for engine interfaces that support generation, weight updating, and memory control.
+    This base class provides a unified API for both HTTP-based and in-process engines.
+    """
+
+    @abstractmethod
+    def generate(
+        self,
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        image_data: Optional[Union[List[str], str]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
+        lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
+        custom_logit_processor: Optional[Union[List[str], str]] = None,
+    ) -> Union[Dict, Iterator[Dict]]:
+        """Generate outputs based on given inputs."""
+        pass
+
+    @abstractmethod
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]],
+        load_format: Optional[str] = None,
+        flush_cache: bool = True,
+    ):
+        """Update model weights with in-memory tensor data."""
+        pass
+
+    @abstractmethod
+    def release_memory_occupation(self):
+        """Release GPU memory occupation temporarily."""
+        pass
+
+    @abstractmethod
+    def resume_memory_occupation(self):
+        """Resume GPU memory occupation which was previously released."""
+        pass
+
+    @abstractmethod
+    def shutdown(self):
+        """Shutdown the engine and clean up resources."""
+        pass
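To make the contract concrete, here is a minimal, hypothetical implementation (not part of this release) that satisfies the interface; the two real implementations in this diff are Engine and HttpServerEngineAdapter:

    from typing import Dict, List, Optional, Tuple

    import torch

    from sglang.srt.entrypoints.EngineBase import EngineBase


    class EchoEngine(EngineBase):
        # A no-op engine for tests: echoes prompts instead of running a model.
        def generate(self, prompt=None, **kwargs) -> Dict:
            return {"text": prompt}

        def update_weights_from_tensor(
            self,
            named_tensors: List[Tuple[str, torch.Tensor]],
            load_format: Optional[str] = None,
            flush_cache: bool = True,
        ):
            pass  # nothing to update

        def release_memory_occupation(self):
            pass

        def resume_memory_occupation(self):
            pass

        def shutdown(self):
            pass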
@@ -29,6 +29,7 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
 
 import zmq
 import zmq.asyncio
+from PIL.Image import Image
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -37,6 +38,7 @@ import torch
 import uvloop
 
 from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
+from sglang.srt.entrypoints.EngineBase import EngineBase
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
 )
@@ -77,7 +79,7 @@ logger = logging.getLogger(__name__)
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 
-class Engine:
+class Engine(EngineBase):
     """
     The entry point to the inference engine.
@@ -135,9 +137,19 @@ class Engine:
         sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        # The image input. It can be a file name, a url, or base64 encoded string.
-        # See also python/sglang/srt/utils.py:load_image.
-        image_data: Optional[Union[List[str], str]] = None,
+        # The image input. It can be an image instance, file name, URL, or base64 encoded string.
+        # Can be formatted as:
+        # - Single image for a single request
+        # - List of images (one per request in a batch)
+        # - List of lists of images (multiple images per request)
+        # See also python/sglang/srt/utils.py:load_image for more details.
+        image_data: Optional[
+            Union[
+                List[List[Union[Image, str]]],
+                List[Union[Image, str]],
+                Union[Image, str],
+            ]
+        ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -190,9 +202,19 @@ class Engine:
         sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        # The image input. It can be a file name, a url, or base64 encoded string.
-        # See also python/sglang/srt/utils.py:load_image.
-        image_data: Optional[Union[List[str], str]] = None,
+        # The image input. It can be an image instance, file name, URL, or base64 encoded string.
+        # Can be formatted as:
+        # - Single image for a single request
+        # - List of images (one per request in a batch)
+        # - List of lists of images (multiple images per request)
+        # See also python/sglang/srt/utils.py:load_image for more details.
+        image_data: Optional[
+            Union[
+                List[List[Union[Image, str]]],
+                List[Union[Image, str]],
+                Union[Image, str],
+            ]
+        ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -228,7 +250,13 @@ class Engine:
     def encode(
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-        image_data: Optional[Union[List[str], str]] = None,
+        image_data: Optional[
+            Union[
+                List[List[Union[Image, str]]],
+                List[Union[Image, str]],
+                Union[Image, str],
+            ]
+        ] = None,
     ) -> Dict:
         """
         The arguments of this function are the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
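The widened image_data annotation accepts three shapes. A hypothetical sketch (placeholder model and file names; Engine is the class defined in this file):

    from PIL import Image
    import sglang as sgl

    engine = sgl.Engine(model_path="some/vision-language-model")  # placeholder

    # Single image for a single request
    engine.generate(prompt="Describe this.", image_data="cat.png")

    # One image per request in a batch (str and PIL.Image can be mixed)
    engine.generate(
        prompt=["Describe this.", "And this."],
        image_data=["cat.png", Image.open("dog.png")],
    )

    # Multiple images per request (assumed shape: one inner list per request)
    engine.generate(prompt=["Compare these."], image_data=[["cat.png", "dog.png"]])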
@@ -25,8 +25,11 @@ import multiprocessing as multiprocessing
 import os
 import threading
 import time
+from ast import Mult
 from http import HTTPStatus
-from typing import AsyncIterator, Callable, Dict, Optional
+from typing import AsyncIterator, Callable, Dict, Optional, Union
+
+from sglang.srt.model_executor.model_runner import LocalSerializedTensor
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -60,6 +63,7 @@ from sglang.srt.managers.io_struct import (
     SetInternalStateReq,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -80,6 +84,7 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
@@ -411,6 +416,26 @@ async def init_weights_update_group(
         return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
 
 
+@app.post("/update_weights_from_tensor")
+async def update_weights_from_tensor(
+    obj: UpdateWeightsFromTensorReqInput, request: Request
+):
+    """Update the weights from tensor inplace without re-launching the server.
+    Notes:
+    1. Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors.
+    2. HTTP will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model.
+    3. Any binary data in the named tensors should be base64 encoded.
+    """
+
+    success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    return ORJSONResponse(
+        content, status_code=200 if success else HTTPStatus.BAD_REQUEST
+    )
+
+
 @app.post("/update_weights_from_distributed")
 async def update_weights_from_distributed(
     obj: UpdateWeightsFromDistributedReqInput, request: Request
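A sketch of a client-side call, mirroring the payload that HttpServerEngineAdapter (added later in this diff) constructs; the host, port, and tensor are placeholders, and since only metadata travels over HTTP, client and server are assumed to share the GPUs:

    import requests
    import torch

    from sglang.srt.utils import MultiprocessingSerializer

    named_tensors = [("lm_head.weight", torch.zeros(8, 8))]  # placeholder
    payload = {
        # One serialized copy per TP rank; output_str=True yields base64 text.
        "serialized_named_tensors": [
            MultiprocessingSerializer.serialize(named_tensors, output_str=True)
        ],
        "load_format": None,
        "flush_cache": True,
    }
    resp = requests.post(
        "http://localhost:30000/update_weights_from_tensor", json=payload
    )
    print(resp.json())  # {"success": ..., "message": ...}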
@@ -785,13 +810,17 @@ def _wait_and_warmup(
     json_data["sampling_params"]["max_new_tokens"] = 0
 
     try:
-        res = requests.post(
-            url + request_name,
-            json=json_data,
-            headers=headers,
-            timeout=600,
-        )
-        assert res.status_code == 200, f"{res}"
+        if server_args.disaggregation_mode == "null":
+            res = requests.post(
+                url + request_name,
+                json=json_data,
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200, f"{res}"
+        else:
+            # Warmup request currently hangs in disaggregation mode, so we skip it.
+            logger.info("Skipping warmup request in disaggregation mode")
     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
@@ -0,0 +1,142 @@
+import base64
+import copy
+import dataclasses
+import multiprocessing
+import pickle
+import threading
+import time
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import requests
+import torch
+import torch.distributed as dist
+
+from sglang.srt.entrypoints.EngineBase import EngineBase
+from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree
+
+
+def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:
+    p = multiprocessing.Process(target=launch_server, args=(server_args,))
+    p.start()
+
+    base_url = server_args.url()
+    timeout = 300.0  # Increased timeout to 5 minutes for downloading large models
+    start_time = time.time()
+
+    with requests.Session() as session:
+        while time.time() - start_time < timeout:
+            try:
+                headers = {
+                    "Content-Type": "application/json; charset=utf-8",
+                    "Authorization": f"Bearer {server_args.api_key}",
+                }
+                response = session.get(f"{base_url}/health_generate", headers=headers)
+                if response.status_code == 200:
+                    return p
+            except requests.RequestException:
+                pass
+
+            if not p.is_alive():
+                raise Exception("Server process terminated unexpectedly.")
+
+            time.sleep(2)
+
+    p.terminate()
+    raise TimeoutError("Server failed to start within the timeout period.")
+
+
+class HttpServerEngineAdapter(EngineBase):
+    """
+    You can use this class to launch a server from a VerlEngine instance.
+    We recommend using this class only when you need the HTTP server;
+    otherwise, you can use Engine directly.
+    """
+
+    def __init__(self, **kwargs):
+        self.server_args = ServerArgs(**kwargs)
+        print(
+            f"Launch HttpServerEngineAdapter at: {self.server_args.host}:{self.server_args.port}"
+        )
+        self.process = launch_server_process(self.server_args)
+
+    def _make_request(self, endpoint: str, payload: Optional[dict] = None):
+        """Make a POST request to the specified endpoint with the given payload.
+
+        Args:
+            endpoint: The API endpoint to call
+            payload: The JSON payload to send (default: empty dict)
+
+        Returns:
+            The JSON response from the server
+        """
+        url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}"
+        response = requests.post(url, json=payload or {})
+        response.raise_for_status()
+        return response.json()
+
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]],
+        load_format: Optional[str] = None,
+        flush_cache: bool = False,
+    ):
+        """
+        Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs.
+
+        Note: The model should be on GPUs rather than CPU for this functionality to work properly.
+        If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
+        """
+
+        return self._make_request(
+            "update_weights_from_tensor",
+            {
+                "serialized_named_tensors": [
+                    MultiprocessingSerializer.serialize(named_tensors, output_str=True)
+                    for _ in range(self.server_args.tp_size)
+                ],
+                "load_format": load_format,
+                "flush_cache": flush_cache,
+            },
+        )
+
+    def shutdown(self):
+        kill_process_tree(self.process.pid)
+
+    def generate(
+        self,
+        prompt=None,
+        sampling_params=None,
+        input_ids=None,
+        image_data=None,
+        return_logprob=False,
+        logprob_start_len=None,
+        top_logprobs_num=None,
+        token_ids_logprob=None,
+        lora_path=None,
+        custom_logit_processor=None,
+    ):
+        payload = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "input_ids": input_ids,
+            "image_data": image_data,
+            "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
+            "top_logprobs_num": top_logprobs_num,
+            "token_ids_logprob": token_ids_logprob,
+            "lora_path": lora_path,
+            "custom_logit_processor": custom_logit_processor,
+        }
+        # Filter out None values
+        payload = {k: v for k, v in payload.items() if v is not None}
+
+        return self._make_request("generate", payload)
+
+    def release_memory_occupation(self):
+        return self._make_request("release_memory_occupation")
+
+    def resume_memory_occupation(self):
+        return self._make_request("resume_memory_occupation")
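A hypothetical end-to-end usage of the adapter (placeholder model path and port; any ServerArgs field can be passed as a keyword argument):

    from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter

    engine = HttpServerEngineAdapter(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
        port=30000,
    )
    try:
        out = engine.generate(
            prompt="Hello, world!",
            sampling_params={"max_new_tokens": 16},
        )
        print(out)
    finally:
        engine.shutdown()  # kills the spawned server process tree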