sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/transfer_engine.py
@@ -0,0 +1,108 @@
+import json
+import logging
+import os
+import uuid
+from dataclasses import dataclass
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class MooncakeTransferEngineConfig:
+    local_hostname: str
+    metadata_server: str
+    protocol: str
+    device_name: str
+
+    @staticmethod
+    def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
+        """Load the config from a JSON file."""
+        with open(file_path) as fin:
+            config = json.load(fin)
+        return MooncakeTransferEngineConfig(
+            local_hostname=config.get("local_hostname", None),
+            metadata_server=config.get("metadata_server"),
+            protocol=config.get("protocol", "rdma"),
+            device_name=config.get("device_name", ""),
+        )
+
+    @staticmethod
+    def load_from_env() -> "MooncakeTransferEngineConfig":
+        """Load config from a file specified in the environment variable."""
+        config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
+        if config_file_path is None:
+            raise ValueError(
+                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
+            )
+        return MooncakeTransferEngineConfig.from_file(config_file_path)
+
+
+class MooncakeTransferEngine:
+
+    def __init__(self):
+        try:
+            from mooncake.engine import TransferEngine
+        except ImportError as e:
+            raise ImportError(
+                "Please install mooncake by following the instructions at "
+                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
+                "to run SGLang with MooncakeTransferEngine."
+            ) from e
+
+        self.engine = TransferEngine()
+
+        try:
+            self.config = MooncakeTransferEngineConfig.load_from_env()
+            logger.info("Mooncake Configuration loaded successfully.")
+        except ValueError as e:
+            logger.error(e)
+            raise
+        except Exception as exc:
+            logger.error("An error occurred while loading the configuration: %s", exc)
+            raise
+
+        self.config = MooncakeTransferEngineConfig.load_from_env()
+
+        session_suffix = "_" + str(uuid.uuid4())
+        self.session_id = self.config.local_hostname + session_suffix
+        self.initialize(
+            self.session_id,
+            self.config.metadata_server,
+            self.config.protocol,
+            self.config.device_name,
+        )
+
+    def register(self, ptr, length):
+        self.engine.register_memory(ptr, length)
+
+    def deregister(self, ptr):
+        self.engine.unregister_memory(ptr)
+
+    def initialize(
+        self,
+        local_hostname: str,
+        metadata_server: str,
+        protocol: str,
+        device_name: str,
+    ) -> None:
+        """Initialize the mooncake instance."""
+        self.engine.initialize(local_hostname, metadata_server, protocol, device_name)
+
+    def transfer_sync(
+        self, session_id: str, buffer: int, peer_buffer_address: int, length: int
+    ) -> int:
+        """Synchronously transfer data to the specified address."""
+
+        ret = self.engine.transfer_sync_write(
+            session_id, buffer, peer_buffer_address, length
+        )
+        if ret < 0:
+            logger.error("Transfer Return Error")
+            raise Exception("Transfer Return Error")
+        return ret
+
+    def get_localhost(self):
+        return self.config.local_hostname
+
+    def get_session_id(self):
+        return self.session_id
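The new transfer engine is configured entirely through a JSON file named by the MOONCAKE_CONFIG_PATH environment variable. A minimal wiring sketch, assuming the mooncake package is installed; the hostname, metadata server, and NIC name below are placeholder values, not taken from this diff:

import json
import os
import tempfile

# Hypothetical cluster values -- substitute your own.
config = {
    "local_hostname": "10.0.0.1",        # address peers use to reach this node
    "metadata_server": "10.0.0.2:2379",  # e.g. an etcd endpoint
    "protocol": "rdma",                  # the default applied by from_file()
    "device_name": "mlx5_0",             # RDMA NIC; "" lets mooncake choose
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(config, f)

os.environ["MOONCAKE_CONFIG_PATH"] = f.name

from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine

# Connects to the metadata server with session id "<local_hostname>_<uuid4>".
engine = MooncakeTransferEngine()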
sglang/srt/disaggregation/prefill.py
@@ -10,9 +10,9 @@ Life cycle of a request in the prefill server
 2. Waiting Queue
     a. Use PrefillAdder to pop requests
     b. Run forward
-    c. Add the request to Infight Queue
+    c. Add the request to Inflight Queue

-3. Infight Queue
+3. Inflight Queue
     a. Poll (non-blocking) the sender of the request
     b. Once the transfer has finished, return the request
 """
@@ -24,9 +24,13 @@ from typing import TYPE_CHECKING, List, Optional

 import torch

-from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
+from sglang.srt.disaggregation.base import BaseKVManager, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    KVClassType,
     ReqToMetadataIdxAllocator,
+    TransferBackend,
+    get_kv_class,
     poll_and_all_reduce,
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
@@ -37,6 +41,7 @@ if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
     from sglang.srt.mem_cache.memory_pool import KVCache

+
 logger = logging.getLogger(__name__)


@@ -55,6 +60,8 @@ class PrefillBootstrapQueue:
         tp_size: int,
         bootstrap_port: int,
         gloo_group: ProcessGroup,
+        transfer_backend: TransferBackend,
+        scheduler: Scheduler,
     ):
         self.token_to_kv_pool = token_to_kv_pool
         self.aux_dtype = aux_dtype
@@ -63,17 +70,19 @@ class PrefillBootstrapQueue:
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.transfer_backend = transfer_backend
+        self.scheduler = scheduler
         self.kv_manager = self._init_kv_manager()
         self.queue: List[Req] = []
         self.gloo_group = gloo_group
         self.bootstrap_port = bootstrap_port

-    def allocate_token_id(self, idx: int, token_id: int):
+    def store_prefill_results(self, idx: int, token_id: int):
         assert token_id >= 0, f"token_id: {token_id} is negative"
         output_id_buffer = self.metadata_buffers[0]
         output_id_buffer[idx] = token_id

-    def _init_kv_manager(self) -> KVManager:
+    def _init_kv_manager(self) -> BaseKVManager:
         kv_args = KVArgs()
         kv_args.engine_rank = self.tp_rank
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
@@ -95,11 +104,16 @@ class PrefillBootstrapQueue:
             metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
         ]
         kv_args.ib_device = "mock-ib-device"
-        kv_manager = KVManager(kv_args)
+        kv_args.gpu_id = self.scheduler.gpu_id
+        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
+        kv_manager = kv_manager_class(
+            kv_args, DisaggregationMode.PREFILL, self.scheduler.server_args
+        )
         return kv_manager

     def add(self, req: Req) -> None:
-        req.disagg_kv_sender = KVSender(
+        kv_sender_class = get_kv_class(self.transfer_backend, KVClassType.SENDER)
+        req.disagg_kv_sender = kv_sender_class(
             mgr=self.kv_manager,
             bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
             bootstrap_room=req.bootstrap_room,
@@ -131,7 +145,7 @@ class PrefillBootstrapQueue:
             elif poll == KVPoll.Failed:
                 raise Exception("Bootstrap failed")

-            # KV.WaitingForInput - init here
+            # KV.WaitingForInput
             num_kv_indices = len(req.origin_input_ids)
             if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
                 break
@@ -161,7 +175,7 @@ class SchedulerDisaggregationPrefillMixin:
         self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
     ) -> None:
         """
-        Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
+        Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
         Adapted from process_batch_result_prefill
         """

@@ -174,7 +188,7 @@ class SchedulerDisaggregationPrefillMixin:
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
                 self.send_kv_chunk(req, token_id=next_token_id)
-                self.disagg_prefill_infight_queue.append(req)
+                self.disagg_prefill_inflight_queue.append(req)
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1
@@ -186,35 +200,41 @@ class SchedulerDisaggregationPrefillMixin:
             self.current_stream.synchronize()
             batch.next_batch_sampling_info.sampling_info_done.set()

-    def process_disagg_prefill_infight_queue(self: Scheduler) -> None:
+    def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
         Poll the requests in the middle of transfer. If done, return the request.
         """
-        assert len(self.disagg_prefill_infight_queue) > 0
+        assert len(self.disagg_prefill_inflight_queue) > 0

         done_reqs = []

         polls = poll_and_all_reduce(
-            [req.disagg_kv_sender for req in self.disagg_prefill_infight_queue],
+            [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
             self.tp_worker.get_tp_cpu_group(),
         )

         undone_reqs: List[Req] = []
-        # Check .poll() for the reqs in disagg_prefill_infight_queue. If Success, respond to the client and remove it from the queue
-        for req, poll in zip(self.disagg_prefill_infight_queue, polls):
+        # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
+        for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
             if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
                 undone_reqs.append(req)
             elif poll == KVPoll.Success:  # transfer done
                 self.tree_cache.cache_finished_req(req)  # unlock the tree
                 req.finished_reason = FINISH_LENGTH(length=0)
+                # FIXME: clean up req's data in transfer engine
                 done_reqs.append(req)
             elif poll == KVPoll.Failed:
                 raise Exception("Transferring failed")

+        for req in done_reqs:
+            self.disagg_prefill_pending_queue.req_to_metadata_buffer_idx_allocator.free(
+                req.metadata_buffer_index
+            )
+
         # Stream requests which have finished transfer
         self.stream_output(done_reqs, False, None)

-        self.disagg_prefill_infight_queue = undone_reqs
+        self.disagg_prefill_inflight_queue = undone_reqs

     def process_prefill_chunk(self: Scheduler) -> None:
         if self.last_batch and self.last_batch.forward_mode.is_extend():
@@ -236,14 +256,18 @@ class SchedulerDisaggregationPrefillMixin:
         """
         start_idx = req.start_send_idx
        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
+
+        # Update next start_send_idx
+        req.start_send_idx = end_idx
+
         kv_indices = (
             self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
             .cpu()
             .numpy()
         )
-        req.start_send_idx = end_idx
         if token_id is not None:
-            self.disagg_prefill_pending_queue.allocate_token_id(
+            self.disagg_prefill_pending_queue.store_prefill_results(
                 req.metadata_buffer_index, token_id
             )
-        req.disagg_kv_sender.send(kv_indices)
+        is_last = token_id is not None
+        req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last)
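The sender API changes shape here: send() now receives the slice of token positions the chunk covers and an is_last flag, which is true exactly when the chunk ships together with the first output token. A rough illustration of the surface a sender must now provide (IllustrativeKVSender is hypothetical, not a class in this wheel):

import numpy as np


class IllustrativeKVSender:
    """Hypothetical sender matching the new three-argument send() call."""

    def __init__(self, mgr, bootstrap_addr: str, bootstrap_room: int):
        self.chunks = []

    def send(self, kv_indices: np.ndarray, index_slice: slice, is_last: bool):
        # Each chunk carries its position in the full token sequence, so the
        # receiving side can place it without tracking offsets itself.
        self.chunks.append((index_slice.start, index_slice.stop, kv_indices))
        if is_last:
            # Finalize the transfer; in send_kv_chunk above, is_last is True
            # when token_id is not None (the first output token is ready).
            pass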
sglang/srt/disaggregation/utils.py
@@ -42,3 +42,34 @@ class ReqToMetadataIdxAllocator:

     def free(self, free_index: int):
         self.free_slots.append(free_index)
+
+
+class TransferBackend(Enum):
+    MOONCAKE = "mooncake"
+    FAKE = "fake"
+
+
+class KVClassType(Enum):
+    MANAGER = "manager"
+    SENDER = "sender"
+    RECEIVER = "receiver"
+    BOOTSTRAP_SERVER = "bootstrap_server"
+
+
+def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
+    if transfer_backend == TransferBackend.MOONCAKE:
+        from sglang.srt.disaggregation.mooncake import (
+            MooncakeKVBootstrapServer,
+            MooncakeKVManager,
+            MooncakeKVReceiver,
+            MooncakeKVSender,
+        )
+
+        class_mapping = {
+            KVClassType.MANAGER: MooncakeKVManager,
+            KVClassType.SENDER: MooncakeKVSender,
+            KVClassType.RECEIVER: MooncakeKVReceiver,
+            KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
+        }
+        return class_mapping.get(class_type)
+    raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
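A usage sketch of the new factory (the mooncake import happens lazily, so this requires the mooncake dependency; FAKE has no branch here and falls through to the ValueError):

from sglang.srt.disaggregation.utils import KVClassType, TransferBackend, get_kv_class

# Resolves to MooncakeKVSender without importing mooncake at module load time.
sender_cls = get_kv_class(TransferBackend.MOONCAKE, KVClassType.SENDER)

# Backends without a branch raise:
# ValueError: Unsupported transfer backend: TransferBackend.FAKE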
sglang/srt/entrypoints/EngineBase.py
@@ -0,0 +1,53 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Iterator, List, Optional, Tuple, Union
+
+import torch
+
+
+class EngineBase(ABC):
+    """
+    Abstract base class for engine interfaces that support generation, weight updating, and memory control.
+    This base class provides a unified API for both HTTP-based engines and in-process engines.
+    """
+
+    @abstractmethod
+    def generate(
+        self,
+        prompt: Optional[Union[List[str], str]] = None,
+        sampling_params: Optional[Union[List[Dict], Dict]] = None,
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
+        image_data: Optional[Union[List[str], str]] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
+        lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
+        custom_logit_processor: Optional[Union[List[str], str]] = None,
+    ) -> Union[Dict, Iterator[Dict]]:
+        """Generate outputs based on given inputs."""
+        pass
+
+    @abstractmethod
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]],
+        load_format: Optional[str] = None,
+        flush_cache: bool = True,
+    ):
+        """Update model weights with in-memory tensor data."""
+        pass
+
+    @abstractmethod
+    def release_memory_occupation(self):
+        """Release GPU memory occupation temporarily."""
+        pass
+
+    @abstractmethod
+    def resume_memory_occupation(self):
+        """Resume GPU memory occupation that was previously released."""
+        pass
+
+    @abstractmethod
+    def shutdown(self):
+        """Shutdown the engine and clean up resources."""
+        pass
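To show the required surface, here is a minimal concrete subclass; the stub bodies are illustrative only and are not behavior from this wheel:

from typing import Dict, Iterator, List, Optional, Tuple, Union

import torch

from sglang.srt.entrypoints.EngineBase import EngineBase


class EchoEngine(EngineBase):
    """Hypothetical engine satisfying the abstract interface."""

    def generate(self, prompt=None, sampling_params=None, input_ids=None,
                 image_data=None, return_logprob=False, logprob_start_len=None,
                 top_logprobs_num=None, token_ids_logprob=None, lora_path=None,
                 custom_logit_processor=None) -> Union[Dict, Iterator[Dict]]:
        return {"text": prompt}  # echo instead of running a model

    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]],
                                   load_format: Optional[str] = None,
                                   flush_cache: bool = True):
        pass  # nothing to update in a stub

    def release_memory_occupation(self):
        pass

    def resume_memory_occupation(self):
        pass

    def shutdown(self):
        pass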
sglang/srt/entrypoints/engine.py
@@ -29,6 +29,7 @@ from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union

 import zmq
 import zmq.asyncio
+from PIL.Image import Image

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -37,6 +38,7 @@ import torch
 import uvloop

 from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
+from sglang.srt.entrypoints.EngineBase import EngineBase
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
 )
@@ -77,7 +79,7 @@ logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


-class Engine:
+class Engine(EngineBase):
     """
     The entry point to the inference engine.
@@ -135,9 +137,19 @@ class Engine:
         sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        # The image input. It can be a file name, a url, or base64 encoded string.
-        # See also python/sglang/srt/utils.py:load_image.
-        image_data: Optional[Union[List[str], str]] = None,
+        # The image input. It can be an image instance, file name, URL, or base64 encoded string.
+        # Can be formatted as:
+        # - Single image for a single request
+        # - List of images (one per request in a batch)
+        # - List of lists of images (multiple images per request)
+        # See also python/sglang/srt/utils.py:load_image for more details.
+        image_data: Optional[
+            Union[
+                List[List[Union[Image, str]]],
+                List[Union[Image, str]],
+                Union[Image, str],
+            ]
+        ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -190,9 +202,19 @@ class Engine:
         sampling_params: Optional[Union[List[Dict], Dict]] = None,
         # The token ids for text; one can either specify text or input_ids.
         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-        # The image input. It can be a file name, a url, or base64 encoded string.
-        # See also python/sglang/srt/utils.py:load_image.
-        image_data: Optional[Union[List[str], str]] = None,
+        # The image input. It can be an image instance, file name, URL, or base64 encoded string.
+        # Can be formatted as:
+        # - Single image for a single request
+        # - List of images (one per request in a batch)
+        # - List of lists of images (multiple images per request)
+        # See also python/sglang/srt/utils.py:load_image for more details.
+        image_data: Optional[
+            Union[
+                List[List[Union[Image, str]]],
+                List[Union[Image, str]],
+                Union[Image, str],
+            ]
+        ] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -228,7 +250,13 @@ class Engine:
     def encode(
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
-        image_data: Optional[Union[List[str], str]] = None,
+        image_data: Optional[
+            Union[
+                List[List[Union[Image, str]]],
+                List[Union[Image, str]],
+                Union[Image, str],
+            ]
+        ] = None,
     ) -> Dict:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`.
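The widened annotation means PIL images can be passed directly alongside paths, URLs, or base64 strings. A sketch of the three accepted shapes (model path and file names are placeholders):

from PIL import Image

import sglang as sgl

llm = sgl.Engine(model_path="some/vlm-model")  # placeholder model path
img = Image.open("cat.png")

# 1. A single image for a single request
llm.generate(prompt="Describe this.", image_data=img)

# 2. One image per request in a batch
llm.generate(prompt=["Describe A.", "Describe B."], image_data=[img, "dog.png"])

# 3. Multiple images for one request
llm.generate(prompt="Compare these.", image_data=[[img, "dog.png"]])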
sglang/srt/entrypoints/http_server.py
@@ -25,8 +25,11 @@ import multiprocessing as multiprocessing
 import os
 import threading
 import time
+from ast import Mult
 from http import HTTPStatus
-from typing import AsyncIterator, Callable, Dict, Optional
+from typing import AsyncIterator, Callable, Dict, Optional, Union
+
+from sglang.srt.model_executor.model_runner import LocalSerializedTensor

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -60,6 +63,7 @@ from sglang.srt.managers.io_struct import (
     SetInternalStateReq,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
+    UpdateWeightsFromTensorReqInput,
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -80,6 +84,7 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     delete_directory,
@@ -411,6 +416,26 @@ async def init_weights_update_group(
         return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


+@app.post("/update_weights_from_tensor")
+async def update_weights_from_tensor(
+    obj: UpdateWeightsFromTensorReqInput, request: Request
+):
+    """Update the weights from tensor inplace without re-launching the server.
+    Notes:
+    1. Ensure that the model is on the correct device (e.g., GPU) before calling this endpoint. If the model is moved to the CPU unexpectedly, it may cause performance issues or runtime errors.
+    2. HTTP will transmit only the metadata of the tensor, while the tensor itself will be directly copied to the model.
+    3. Any binary data in the named tensors should be base64 encoded.
+    """
+
+    success, message = await _global_state.tokenizer_manager.update_weights_from_tensor(
+        obj, request
+    )
+    content = {"success": success, "message": message}
+    return ORJSONResponse(
+        content, status_code=200 if success else HTTPStatus.BAD_REQUEST
+    )
+
+
 @app.post("/update_weights_from_distributed")
 async def update_weights_from_distributed(
     obj: UpdateWeightsFromDistributedReqInput, request: Request
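A sketch of calling the new endpoint directly, mirroring the payload that HttpServerEngineAdapter (added below in this release) builds; host, port, and the tensor name are placeholders, and tp_size=1 is assumed:

import requests
import torch

from sglang.srt.utils import MultiprocessingSerializer

named_tensors = [("lm_head.weight", torch.zeros(8, 8, device="cuda"))]

payload = {
    # One serialized copy per TP rank; output_str=True base64-encodes the
    # bytes so they survive the JSON body (note 3 above).
    "serialized_named_tensors": [
        MultiprocessingSerializer.serialize(named_tensors, output_str=True)
    ],
    "load_format": None,
    "flush_cache": True,
}

r = requests.post("http://127.0.0.1:30000/update_weights_from_tensor", json=payload)
print(r.json())  # {"success": ..., "message": ...}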
@@ -785,13 +810,17 @@ def _wait_and_warmup(
         json_data["sampling_params"]["max_new_tokens"] = 0

     try:
-        res = requests.post(
-            url + request_name,
-            json=json_data,
-            headers=headers,
-            timeout=600,
-        )
-        assert res.status_code == 200, f"{res}"
+        if server_args.disaggregation_mode == "null":
+            res = requests.post(
+                url + request_name,
+                json=json_data,
+                headers=headers,
+                timeout=600,
+            )
+            assert res.status_code == 200, f"{res}"
+        else:
+            # Warmup request currently hangs in disaggregation mode, so we skip it.
+            logger.info("Skipping warmup request in disaggregation mode")
     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
sglang/srt/entrypoints/http_server_engine.py
@@ -0,0 +1,142 @@
+import base64
+import copy
+import dataclasses
+import multiprocessing
+import pickle
+import threading
+import time
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import requests
+import torch
+import torch.distributed as dist
+
+from sglang.srt.entrypoints.EngineBase import EngineBase
+from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import MultiprocessingSerializer, kill_process_tree
+
+
+def launch_server_process(server_args: ServerArgs) -> multiprocessing.Process:
+
+    p = multiprocessing.Process(target=launch_server, args=(server_args,))
+    p.start()
+
+    base_url = server_args.url()
+    timeout = 300.0  # Increased timeout to 5 minutes for downloading large models
+    start_time = time.time()
+
+    with requests.Session() as session:
+        while time.time() - start_time < timeout:
+            try:
+                headers = {
+                    "Content-Type": "application/json; charset=utf-8",
+                    "Authorization": f"Bearer {server_args.api_key}",
+                }
+                response = session.get(f"{base_url}/health_generate", headers=headers)
+                if response.status_code == 200:
+                    return p
+            except requests.RequestException:
+                pass
+
+            if not p.is_alive():
+                raise Exception("Server process terminated unexpectedly.")
+
+            time.sleep(2)
+
+    p.terminate()
+    raise TimeoutError("Server failed to start within the timeout period.")
+
+
+class HttpServerEngineAdapter(EngineBase):
+    """
+    You can use this class to launch a server from a VerlEngine instance.
+    We recommend using this class only when you need the HTTP server.
+    Otherwise, you can use Engine directly.
+    """
+
+    def __init__(self, **kwargs):
+        self.server_args = ServerArgs(**kwargs)
+        print(
+            f"Launch HttpServerEngineAdapter at: {self.server_args.host}:{self.server_args.port}"
+        )
+        self.process = launch_server_process(self.server_args)
+
+    def _make_request(self, endpoint: str, payload: Optional[dict] = None):
+        """Make a POST request to the specified endpoint with the given payload.
+
+        Args:
+            endpoint: The API endpoint to call
+            payload: The JSON payload to send (default: empty dict)
+
+        Returns:
+            The JSON response from the server
+        """
+        url = f"http://{self.server_args.host}:{self.server_args.port}/{endpoint}"
+        response = requests.post(url, json=payload or {})
+        response.raise_for_status()
+        return response.json()
+
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, torch.Tensor]],
+        load_format: Optional[str] = None,
+        flush_cache: bool = False,
+    ):
+        """
+        Update model weights from tensor data. The HTTP server will only post meta data, and the real weights will be copied directly from GPUs.
+
+        Note: The model should be on GPUs rather than CPU for this functionality to work properly.
+        If you encounter issues, ensure your model is loaded on GPU devices rather than CPU.
+        """
+
+        return self._make_request(
+            "update_weights_from_tensor",
+            {
+                "serialized_named_tensors": [
+                    MultiprocessingSerializer.serialize(named_tensors, output_str=True)
+                    for _ in range(self.server_args.tp_size)
+                ],
+                "load_format": load_format,
+                "flush_cache": flush_cache,
+            },
+        )
+
+    def shutdown(self):
+        kill_process_tree(self.process.pid)
+
+    def generate(
+        self,
+        prompt=None,
+        sampling_params=None,
+        input_ids=None,
+        image_data=None,
+        return_logprob=False,
+        logprob_start_len=None,
+        top_logprobs_num=None,
+        token_ids_logprob=None,
+        lora_path=None,
+        custom_logit_processor=None,
+    ):
+        payload = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "input_ids": input_ids,
+            "image_data": image_data,
+            "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
+            "top_logprobs_num": top_logprobs_num,
+            "token_ids_logprob": token_ids_logprob,
+            "lora_path": lora_path,
+            "custom_logit_processor": custom_logit_processor,
+        }
+        # Filter out None values
+        payload = {k: v for k, v in payload.items() if v is not None}
+
+        return self._make_request("generate", payload)
+
+    def release_memory_occupation(self):
+        return self._make_request("release_memory_occupation")
+
+    def resume_memory_occupation(self):
+        return self._make_request("resume_memory_occupation")
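End-to-end, the adapter can be driven much like the in-process Engine; a usage sketch with placeholder model path and port:

from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter

# Spawns launch_server in a subprocess and polls /health_generate until ready.
engine = HttpServerEngineAdapter(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    port=30000,
)

out = engine.generate(prompt="Hello", sampling_params={"max_new_tokens": 16})
print(out["text"])

engine.shutdown()  # kills the spawned server process tree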