sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +94 -46
- sglang/srt/disaggregation/prefill.py +3 -2
- sglang/srt/disaggregation/utils.py +12 -11
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/openai/protocol.py +47 -4
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/attention/flashattention_backend.py +24 -14
- sglang/srt/layers/layernorm.py +15 -0
- sglang/srt/layers/linear.py +18 -1
- sglang/srt/layers/logits_processor.py +12 -3
- sglang/srt/layers/moe/ep_moe/layer.py +79 -12
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
- sglang/srt/layers/moe/topk.py +26 -0
- sglang/srt/layers/quantization/fp8_utils.py +5 -4
- sglang/srt/layers/rotary_embedding.py +103 -11
- sglang/srt/layers/vocab_parallel_embedding.py +14 -1
- sglang/srt/managers/expert_distribution.py +21 -0
- sglang/srt/managers/io_struct.py +10 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
- sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
- sglang/srt/managers/schedule_batch.py +9 -1
- sglang/srt/managers/scheduler.py +42 -6
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -2
- sglang/srt/model_loader/loader.py +45 -10
- sglang/srt/model_loader/weight_utils.py +89 -0
- sglang/srt/models/deepseek_nextn.py +7 -4
- sglang/srt/models/deepseek_v2.py +147 -4
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1009 -0
- sglang/srt/models/gemma3n_mm.py +511 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/server_args.py +16 -2
- sglang/srt/two_batch_overlap.py +4 -1
- sglang/srt/utils.py +71 -0
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/nixl/conn.py
CHANGED
@@ -31,23 +31,19 @@ from sglang.srt.utils import get_local_ip_by_remote
 
 logger = logging.getLogger(__name__)
 
-NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
-
 GUARD = "NixlMsgGuard".encode("ascii")
 
 
 @dataclasses.dataclass
 class TransferInfo:
+    """Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""
+
     room: int
     endpoint: str
     dst_port: int
-    agent_metadata: bytes
     agent_name: str
-    dst_kv_ptrs: list[int]
     dst_kv_indices: npt.NDArray[np.int32]
-    dst_aux_ptrs: list[int]
     dst_aux_index: int
-    dst_gpu_id: int
     required_dst_info_num: int
 
     def is_dummy(self):
@@ -59,14 +55,37 @@ class TransferInfo:
             room=int(msg[0].decode("ascii")),
             endpoint=msg[1].decode("ascii"),
             dst_port=int(msg[2].decode("ascii")),
-            agent_metadata=msg[3],
-            agent_name=msg[4].decode("ascii"),
+            agent_name=msg[3].decode("ascii"),
+            dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
+            dst_aux_index=int(msg[5].decode("ascii")),
+            required_dst_info_num=int(msg[6].decode("ascii")),
+        )
+
+
+@dataclasses.dataclass
+class KVArgsRegisterInfo:
+    """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
+
+    room: str
+    endpoint: str
+    dst_port: int
+    agent_name: str
+    agent_metadata: bytes
+    dst_kv_ptrs: list[int]
+    dst_aux_ptrs: list[int]
+    gpu_id: int
+
+    @classmethod
+    def from_zmq(cls, msg: List[bytes]):
+        return cls(
+            room=str(msg[0].decode("ascii")),
+            endpoint=msg[1].decode("ascii"),
+            dst_port=int(msg[2].decode("ascii")),
+            agent_name=msg[3].decode("ascii"),
+            agent_metadata=msg[4],
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
-            dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
-            dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
-            dst_aux_index=int(msg[8].decode("ascii")),
-            dst_gpu_id=int(msg[9].decode("ascii")),
-            required_dst_info_num=int(msg[10].decode("ascii")),
+            dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
+            gpu_id=int(msg[7].decode("ascii")),
         )
 
 
@@ -109,9 +128,9 @@ class NixlKVManager(CommonKVManager):
         self.register_buffer_to_engine()
 
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            self.request_status = {}
-            self.transfer_infos: Dict[int, TransferInfo] = {}
-            self.
+            self.request_status: Dict[int, KVPoll] = {}
+            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
+            self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self._start_bootstrap_thread()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
@@ -154,10 +173,13 @@ class NixlKVManager(CommonKVManager):
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
 
-    def
-
-
-
+    def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
+        agent_name = decode_kv_args.agent_name
+        if agent_name in self.decode_kv_args_table:
+            logger.info(f"Peer {agent_name} was already registered, ignoring.")
+            return
+        self.decode_kv_args_table[agent_name] = decode_kv_args
+        self.agent.add_remote_agent(decode_kv_args.agent_metadata)
 
     def send_kvcache(
         self,
@@ -262,17 +284,17 @@ class NixlKVManager(CommonKVManager):
             if req.is_dummy():
                 continue
 
-            peer_name = self._add_remote(req.agent_name, req.agent_metadata)
             chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
             assert len(chunked_dst_kv_indice) == len(kv_indices)
+            assert req.agent_name in self.decode_kv_args_table
 
             notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
             kv_xfer_handle = self.send_kvcache(
-                peer_name,
+                req.agent_name,
                 kv_indices,
-                req.dst_kv_ptrs,
+                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
                 chunked_dst_kv_indice,
-                req.
+                self.decode_kv_args_table[req.agent_name].gpu_id,
                 notif,
             )
             handles.append(kv_xfer_handle)
@@ -280,13 +302,15 @@
             if is_last:
                 assert aux_index is not None
                 aux_xfer_handle = self.send_aux(
-                    peer_name,
+                    req.agent_name,
                     aux_index,
-                    req.dst_aux_ptrs,
+                    self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
                     req.dst_aux_index,
                     str(req.room) + "_aux",
                 )
                 handles.append(aux_xfer_handle)
+        if is_last:
+            del self.transfer_infos[bootstrap_room]
         return handles
 
     def update_transfer_status(self):
@@ -328,16 +352,23 @@ class NixlKVManager(CommonKVManager):
                 ), f"First message should be {GUARD}. Foreign traffic?"
                 waiting_req_bytes = waiting_req_bytes[1:]
                 room = waiting_req_bytes[0].decode("ascii")
-
-
+                agent_name = waiting_req_bytes[3].decode("ascii")
+                if room == "None":
+                    # Register new peer and save KV base pointers.
+                    self._add_remote_peer(
+                        KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
+                    )
+                    logger.debug(f"Register KVArgs from {agent_name} successfully")
+                    continue
                 room = int(room)
-                agent_name = waiting_req_bytes[4].decode("ascii")
                 if room not in self.transfer_infos:
                     self.transfer_infos[room] = {}
                 self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
                     waiting_req_bytes
                 )
-
+                required_dst_info_num = self.transfer_infos[room][
+                    agent_name
+                ].required_dst_info_num
                 logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
                 if len(self.transfer_infos[room]) == required_dst_info_num:
                     logger.debug(f"{room=} is bootstrapped")
@@ -391,6 +422,7 @@ class NixlKVSender(BaseKVSender):
         self.chunk_id += 1
         if is_last:
             self.has_sent = True
+            del self.kv_mgr.request_status[self.bootstrap_room]
 
     def poll(self) -> KVPoll:
         if not self.has_sent:
@@ -415,6 +447,7 @@ class NixlKVReceiver(CommonKVReceiver):
         data_parallel_rank: Optional[int] = None,
     ):
         self.started_transfer = False
+        self.conclude_state = None
         super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
 
     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
@@ -426,17 +459,8 @@ class NixlKVReceiver(CommonKVReceiver):
                 f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
             )
             is_dummy = bootstrap_info["is_dummy"]
-
-            # TODO: send_kv_args earlier
-            packed_kv_data_ptrs = b"".join(
-                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
-            )
-            packed_aux_data_ptrs = b"".join(
-                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
-            )
-
             logger.debug(
-                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
+                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
             )
             sock, lock = self._connect("tcp://" + self.prefill_server_url)
             with lock:
@@ -446,13 +470,9 @@ class NixlKVReceiver(CommonKVReceiver):
                         str(self.bootstrap_room).encode("ascii"),
                         get_local_ip_by_remote().encode("ascii"),
                         str(self.kv_mgr.rank_port).encode("ascii"),
-                        self.kv_mgr.agent.get_agent_metadata(),
                         self.kv_mgr.agent.name.encode("ascii"),
-                        packed_kv_data_ptrs,
                         kv_indices.tobytes() if not is_dummy else b"",
-                        packed_aux_data_ptrs,
                         str(aux_index).encode("ascii"),
-                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
                         str(self.required_dst_info_num).encode("ascii"),
                     ]
                 )
@@ -460,17 +480,45 @@
         self.started_transfer = True
 
     def poll(self) -> KVPoll:
+        if self.conclude_state is not None:
+            return self.conclude_state
         if not self.started_transfer:
             return KVPoll.WaitingForInput  # type: ignore
 
         self.kv_mgr.update_transfer_status()
-
         if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
+            self.conclude_state = KVPoll.Success
+            del self.kv_mgr.transfer_statuses[self.bootstrap_room]
             return KVPoll.Success  # type: ignore
         return KVPoll.WaitingForInput  # type: ignore
 
     def _register_kv_args(self):
-
+        for bootstrap_info in self.bootstrap_infos:
+            self.prefill_server_url = (
+                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
+            )
+            packed_kv_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
+            )
+            packed_aux_data_ptrs = b"".join(
+                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
+            )
+
+            sock, lock = self._connect("tcp://" + self.prefill_server_url)
+            with lock:
+                sock.send_multipart(
+                    [
+                        GUARD,
+                        "None".encode("ascii"),
+                        get_local_ip_by_remote().encode("ascii"),
+                        str(self.kv_mgr.rank_port).encode("ascii"),
+                        self.kv_mgr.agent.name.encode("ascii"),
+                        self.kv_mgr.agent.get_agent_metadata(),
+                        packed_kv_data_ptrs,
+                        packed_aux_data_ptrs,
+                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                    ]
+                )
 
     def failure_exception(self):
         raise Exception("Fake KVReceiver Exception")
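Taken together, the nixl/conn.py hunks above replace the old single bootstrap message with a one-time registration message (room "None", parsed by KVArgsRegisterInfo.from_zmq) plus a smaller per-request transfer message (parsed by TransferInfo.from_zmq). The sketch below restates the two multipart layouts in plain Python; only the field order is taken from the diff, and all concrete values are illustrative placeholders.

    import struct
    import numpy as np

    GUARD = b"NixlMsgGuard"

    # Placeholder values; real ones come from the NIXL agent and the KV manager.
    local_ip, rank_port, agent_name = b"10.0.0.1", b"7010", b"decode-agent-0"
    agent_metadata = b"\x00" * 16
    kv_base_ptrs, aux_base_ptrs = [0x7F0000000000], [0x7F0000100000]

    # Sent once per peer by _register_kv_args(); parsed by KVArgsRegisterInfo.from_zmq
    # after the prefill bootstrap thread strips GUARD.
    registration_msg = [
        GUARD,
        b"None",                                 # room placeholder marking a registration
        local_ip,
        rank_port,
        agent_name,
        agent_metadata,
        b"".join(struct.pack("Q", p) for p in kv_base_ptrs),
        b"".join(struct.pack("Q", p) for p in aux_base_ptrs),
        b"0",                                    # gpu_id
    ]

    # Sent per request by NixlKVReceiver.init(); parsed by TransferInfo.from_zmq.
    transfer_msg = [
        GUARD,
        b"12345",                                # bootstrap room
        local_ip,
        rank_port,
        agent_name,
        np.arange(4, dtype=np.int32).tobytes(),  # dst_kv_indices (empty for dummy ranks)
        b"7",                                    # dst_aux_index
        b"1",                                    # required_dst_info_num
    ]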
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -93,8 +93,6 @@ class PrefillBootstrapQueue
         self.gpu_id = gpu_id
         self.bootstrap_port = bootstrap_port
         self.queue: List[Req] = []
-        self.pp_rank = pp_rank
-        self.pp_size = pp_size
         self.gloo_group = gloo_group
         self.max_total_num_tokens = max_total_num_tokens
         self.scheduler = scheduler
@@ -124,6 +122,9 @@ class PrefillBootstrapQueue
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
+        if not self.is_mla_backend:
+            kv_args.kv_head_num = self.token_to_kv_pool.head_num
+        kv_args.page_size = self.token_to_kv_pool.page_size
 
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
sglang/srt/disaggregation/utils.py
CHANGED
@@ -107,9 +107,6 @@ class MetadataBuffers:
         # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
         self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
 
-        self.output_hidden_states = torch.zeros(
-            (size, hidden_size), dtype=dtype, device=device
-        )
         self.output_token_logprobs_val = torch.zeros(
             (size, 16), dtype=torch.float32, device=device
         )
@@ -122,51 +119,50 @@ class MetadataBuffers:
         self.output_top_logprobs_idx = torch.zeros(
             (size, max_top_logprobs_num), dtype=torch.int32, device=device
         )
+        self.output_hidden_states = torch.zeros(
+            (size, hidden_size), dtype=dtype, device=device
+        )
 
     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
-            self.output_hidden_states.data_ptr(),  # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
             self.output_top_logprobs_idx.data_ptr(),
+            self.output_hidden_states.data_ptr(),
         ]
         data_lens = [
             self.output_ids.nbytes,
-            self.output_hidden_states.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
             self.output_top_logprobs_idx.nbytes,
+            self.output_hidden_states.nbytes,
         ]
        item_lens = [
             self.output_ids[0].nbytes,
-            self.output_hidden_states[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
             self.output_top_logprobs_idx[0].nbytes,
+            self.output_hidden_states[0].nbytes,
         ]
         return ptrs, data_lens, item_lens
 
     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
-            self.output_hidden_states[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
             self.output_top_logprobs_idx[idx],
+            self.output_hidden_states[idx],
         )
 
     def set_buf(self, req: Req):
 
         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
-        if req.hidden_states_tensor is not None:
-            self.output_hidden_states[req.metadata_buffer_index].copy_(
-                req.hidden_states_tensor
-            )
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -189,6 +185,11 @@ class MetadataBuffers:
                 ] = torch.tensor(
                     req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
                 )
+        # for PD + spec decode
+        if req.hidden_states_tensor is not None:
+            self.output_hidden_states[req.metadata_buffer_index].copy_(
+                req.hidden_states_tensor
+            )
 
 
 #########################
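Since output_hidden_states now sits in the last slot of every pointer/length list and of the get_buf tuple, any consumer of MetadataBuffers must unpack in the new order. A minimal sketch of that order, assuming an already constructed MetadataBuffers instance (only the ordering is taken from the diff):

    def unpack_metadata_buf(buffers, idx: int):
        # Order matches MetadataBuffers.get_buf() after this patch; hidden states
        # come last and are only populated for PD + speculative decoding.
        (
            output_ids,
            output_token_logprobs_val,
            output_token_logprobs_idx,
            output_top_logprobs_val,
            output_top_logprobs_idx,
            output_hidden_states,
        ) = buffers.get_buf(idx)
        return output_ids, output_hidden_states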
sglang/srt/entrypoints/engine.py
CHANGED
@@ -115,13 +115,13 @@ class Engine(EngineBase):
         atexit.register(self.shutdown)
 
         # Allocate ports for inter-process communications
-        port_args = PortArgs.init_new(server_args)
+        self.port_args = PortArgs.init_new(server_args)
         logger.info(f"{server_args=}")
 
         # Launch subprocesses
         tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args,
-            port_args=port_args,
+            port_args=self.port_args,
         )
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
@@ -130,7 +130,7 @@ class Engine(EngineBase):
 
         context = zmq.Context(2)
         self.send_to_rpc = get_zmq_socket(
-            context, zmq.DEALER, port_args.rpc_ipc_name, True
+            context, zmq.DEALER, self.port_args.rpc_ipc_name, True
         )
 
     def generate(
@@ -242,6 +242,7 @@ class Engine(EngineBase):
         token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
         custom_logit_processor: Optional[Union[List[str], str]] = None,
+        return_hidden_states: bool = False,
         stream: bool = False,
         bootstrap_host: Optional[Union[List[str], str]] = None,
         bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -274,6 +275,7 @@
             top_logprobs_num=top_logprobs_num,
             token_ids_logprob=token_ids_logprob,
             lora_path=lora_path,
+            return_hidden_states=return_hidden_states,
             stream=stream,
             custom_logit_processor=custom_logit_processor,
             bootstrap_host=bootstrap_host,
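The engine.py hunks above add return_hidden_states as a per-call argument to Engine.generate and forward it into the request object. A minimal offline sketch; the model path is a placeholder and the exact location of the hidden states in the output is not shown in this diff:

    import sglang as sgl

    llm = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model

    out = llm.generate(
        prompt="The capital of France is",
        sampling_params={"max_new_tokens": 8},
        return_hidden_states=True,  # new per-request flag in 0.4.8.post1
    )

    # Hidden states are expected alongside the usual fields of the returned dict
    # (commonly under meta_info); inspect the output to confirm for your version.
    print(out["text"] if isinstance(out, dict) else out)

    llm.shutdown()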
sglang/srt/entrypoints/openai/protocol.py
CHANGED
@@ -14,7 +14,8 @@
 """Pydantic models for OpenAI API protocol"""
 
 import time
-from
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
 
 from pydantic import (
     BaseModel,
@@ -195,6 +196,9 @@ class CompletionRequest(BaseModel):
     bootstrap_port: Optional[int] = None
     bootstrap_room: Optional[int] = None
 
+    # For request id
+    rid: Optional[Union[List[str], str]] = None
+
     @field_validator("max_tokens")
     @classmethod
     def validate_max_tokens_positive(cls, v):
@@ -309,6 +313,18 @@ class ChatCompletionMessageGenericParam(BaseModel):
     reasoning_content: Optional[str] = None
     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
 
+    @field_validator("role", mode="before")
+    @classmethod
+    def _normalize_role(cls, v):
+        if isinstance(v, str):
+            v_lower = v.lower()
+            if v_lower not in {"system", "assistant", "tool"}:
+                raise ValueError(
+                    "'role' must be one of 'system', 'assistant', or 'tool' (case-insensitive)."
+                )
+            return v_lower
+        raise ValueError("'role' must be a string")
+
 
 class ChatCompletionMessageUserParam(BaseModel):
     role: Literal["user"]
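The new _normalize_role validator lower-cases the incoming role before the field is validated and rejects anything other than system/assistant/tool for this message type (user messages go through the separate ChatCompletionMessageUserParam model). A small behavioral sketch; it assumes the content field accepts a plain string, which is not shown in this hunk:

    from pydantic import ValidationError

    from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageGenericParam

    # Mixed-case roles are normalized before validation.
    msg = ChatCompletionMessageGenericParam(role="Assistant", content="hi")
    assert msg.role == "assistant"

    # Roles outside {system, assistant, tool} are rejected.
    try:
        ChatCompletionMessageGenericParam(role="developer", content="hi")
    except ValidationError as exc:
        print("rejected:", exc)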
@@ -429,8 +445,8 @@ class ChatCompletionRequest(BaseModel):
     stream_reasoning: bool = True
     chat_template_kwargs: Optional[Dict] = None
 
-    #
-    rid: Optional[str] = None
+    # For request id
+    rid: Optional[Union[List[str], str]] = None
 
     # For PD disaggregation
     bootstrap_host: Optional[str] = None
@@ -528,7 +544,7 @@ class EmbeddingRequest(BaseModel):
     user: Optional[str] = None
 
     # The request id.
-    rid: Optional[str] = None
+    rid: Optional[Union[List[str], str]] = None
 
 
 class EmbeddingObject(BaseModel):
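With rid widened from Optional[str] to Optional[Union[List[str], str]] on CompletionRequest, ChatCompletionRequest, and EmbeddingRequest, a batched call can now carry one request id per prompt. A hedged sketch against an OpenAI-compatible sglang server; the URL, model name, and use of extra_body for the pass-through are assumptions, not something shown in this diff:

    import openai

    client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

    # One custom request id per prompt in a batched completion call.
    resp = client.completions.create(
        model="default",
        prompt=["Hello", "Bonjour"],
        max_tokens=8,
        extra_body={"rid": ["req-hello-0", "req-bonjour-1"]},
    )
    print(len(resp.choices))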
@@ -587,3 +603,30 @@ OpenAIServingRequest = Union[
     ScoringRequest,
     V1RerankReqInput,
 ]
+
+
+@dataclass
+class MessageProcessingResult:
+    """Result of processing chat messages and applying templates.
+
+    This dataclass encapsulates all the outputs from message processing including
+    prompt generation, multimodal data extraction, and constraint preparation.
+    Used internally by OpenAIServingChat to pass processed data between methods.
+
+    Args:
+        prompt: The final text prompt after applying chat template
+        prompt_ids: Either the text prompt (str) or tokenized IDs (List[int])
+        image_data: Extracted image data from messages, if any
+        audio_data: Extracted audio data from messages, if any
+        modalities: List of modality types present in the messages
+        stop: Combined stop strings from template and request
+        tool_call_constraint: Optional constraint for structured tool calls
+    """
+
+    prompt: str
+    prompt_ids: Union[str, List[int]]
+    image_data: Optional[Any]
+    audio_data: Optional[Any]
+    modalities: List[str]
+    stop: List[str]
+    tool_call_constraint: Optional[Any] = None