sglang 0.4.8__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. sglang/srt/configs/model_config.py +1 -0
  2. sglang/srt/conversation.py +1 -0
  3. sglang/srt/custom_op.py +7 -1
  4. sglang/srt/disaggregation/base/conn.py +2 -0
  5. sglang/srt/disaggregation/decode.py +1 -1
  6. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  8. sglang/srt/disaggregation/nixl/conn.py +94 -46
  9. sglang/srt/disaggregation/prefill.py +3 -2
  10. sglang/srt/disaggregation/utils.py +12 -11
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/openai/protocol.py +47 -4
  13. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  14. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  15. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  16. sglang/srt/layers/activation.py +7 -0
  17. sglang/srt/layers/attention/flashattention_backend.py +24 -14
  18. sglang/srt/layers/layernorm.py +15 -0
  19. sglang/srt/layers/linear.py +18 -1
  20. sglang/srt/layers/logits_processor.py +12 -3
  21. sglang/srt/layers/moe/ep_moe/layer.py +79 -12
  22. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  23. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  24. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -2
  25. sglang/srt/layers/moe/fused_moe_triton/layer.py +73 -14
  26. sglang/srt/layers/moe/topk.py +26 -0
  27. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  28. sglang/srt/layers/rotary_embedding.py +103 -11
  29. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  30. sglang/srt/managers/expert_distribution.py +21 -0
  31. sglang/srt/managers/io_struct.py +10 -2
  32. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  33. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  34. sglang/srt/managers/schedule_batch.py +9 -1
  35. sglang/srt/managers/scheduler.py +42 -6
  36. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  37. sglang/srt/model_executor/model_runner.py +5 -2
  38. sglang/srt/model_loader/loader.py +45 -10
  39. sglang/srt/model_loader/weight_utils.py +89 -0
  40. sglang/srt/models/deepseek_nextn.py +7 -4
  41. sglang/srt/models/deepseek_v2.py +147 -4
  42. sglang/srt/models/gemma3n_audio.py +949 -0
  43. sglang/srt/models/gemma3n_causal.py +1009 -0
  44. sglang/srt/models/gemma3n_mm.py +511 -0
  45. sglang/srt/models/hunyuan.py +771 -0
  46. sglang/srt/server_args.py +16 -2
  47. sglang/srt/two_batch_overlap.py +4 -1
  48. sglang/srt/utils.py +71 -0
  49. sglang/version.py +1 -1
  50. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +1 -1
  51. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +54 -49
  52. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  54. {sglang-0.4.8.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
@@ -31,23 +31,19 @@ from sglang.srt.utils import get_local_ip_by_remote
31
31
 
32
32
  logger = logging.getLogger(__name__)
33
33
 
34
- NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
35
-
36
34
  GUARD = "NixlMsgGuard".encode("ascii")
37
35
 
38
36
 
39
37
  @dataclasses.dataclass
40
38
  class TransferInfo:
39
+ """Contains indices for a transfer, sent by KVReceiver. Received by prefill bootstrap thread."""
40
+
41
41
  room: int
42
42
  endpoint: str
43
43
  dst_port: int
44
- agent_metadata: bytes
45
44
  agent_name: str
46
- dst_kv_ptrs: list[int]
47
45
  dst_kv_indices: npt.NDArray[np.int32]
48
- dst_aux_ptrs: list[int]
49
46
  dst_aux_index: int
50
- dst_gpu_id: int
51
47
  required_dst_info_num: int
52
48
 
53
49
  def is_dummy(self):
@@ -59,14 +55,37 @@ class TransferInfo:
59
55
  room=int(msg[0].decode("ascii")),
60
56
  endpoint=msg[1].decode("ascii"),
61
57
  dst_port=int(msg[2].decode("ascii")),
62
- agent_metadata=msg[3],
63
- agent_name=msg[4].decode("ascii"),
58
+ agent_name=msg[3].decode("ascii"),
59
+ dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
60
+ dst_aux_index=int(msg[5].decode("ascii")),
61
+ required_dst_info_num=int(msg[6].decode("ascii")),
62
+ )
63
+
64
+
65
+ @dataclasses.dataclass
66
+ class KVArgsRegisterInfo:
67
+ """Contains base pointers and other info which only needs to be sent once by KVReceiver. Received by prefill bootstrap thread."""
68
+
69
+ room: str
70
+ endpoint: str
71
+ dst_port: int
72
+ agent_name: str
73
+ agent_metadata: bytes
74
+ dst_kv_ptrs: list[int]
75
+ dst_aux_ptrs: list[int]
76
+ gpu_id: int
77
+
78
+ @classmethod
79
+ def from_zmq(cls, msg: List[bytes]):
80
+ return cls(
81
+ room=str(msg[0].decode("ascii")),
82
+ endpoint=msg[1].decode("ascii"),
83
+ dst_port=int(msg[2].decode("ascii")),
84
+ agent_name=msg[3].decode("ascii"),
85
+ agent_metadata=msg[4],
64
86
  dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
65
- dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
66
- dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
67
- dst_aux_index=int(msg[8].decode("ascii")),
68
- dst_gpu_id=int(msg[9].decode("ascii")),
69
- required_dst_info_num=int(msg[10].decode("ascii")),
87
+ dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
88
+ gpu_id=int(msg[7].decode("ascii")),
70
89
  )
71
90
 
72
91
 
@@ -109,9 +128,9 @@ class NixlKVManager(CommonKVManager):
109
128
  self.register_buffer_to_engine()
110
129
 
111
130
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
112
- self.request_status = {}
113
- self.transfer_infos: Dict[int, TransferInfo] = {}
114
- self.peer_names: Dict[str, str] = {}
131
+ self.request_status: Dict[int, KVPoll] = {}
132
+ self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
133
+ self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
115
134
  self._start_bootstrap_thread()
116
135
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
117
136
  self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
@@ -154,10 +173,13 @@ class NixlKVManager(CommonKVManager):
154
173
  if not self.aux_descs:
155
174
  raise Exception("NIXL memory registration failed for aux tensors")
156
175
 
157
- def _add_remote(self, agent_name: str, agent_metadata: bytes):
158
- if agent_name not in self.peer_names:
159
- self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata)
160
- return self.peer_names[agent_name]
176
+ def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
177
+ agent_name = decode_kv_args.agent_name
178
+ if agent_name in self.decode_kv_args_table:
179
+ logger.info(f"Peer {agent_name} was already registered, ignoring.")
180
+ return
181
+ self.decode_kv_args_table[agent_name] = decode_kv_args
182
+ self.agent.add_remote_agent(decode_kv_args.agent_metadata)
161
183
 
162
184
  def send_kvcache(
163
185
  self,
@@ -262,17 +284,17 @@ class NixlKVManager(CommonKVManager):
262
284
  if req.is_dummy():
263
285
  continue
264
286
 
265
- peer_name = self._add_remote(req.agent_name, req.agent_metadata)
266
287
  chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
267
288
  assert len(chunked_dst_kv_indice) == len(kv_indices)
289
+ assert req.agent_name in self.decode_kv_args_table
268
290
 
269
291
  notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
270
292
  kv_xfer_handle = self.send_kvcache(
271
- peer_name,
293
+ req.agent_name,
272
294
  kv_indices,
273
- req.dst_kv_ptrs,
295
+ self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
274
296
  chunked_dst_kv_indice,
275
- req.dst_gpu_id,
297
+ self.decode_kv_args_table[req.agent_name].gpu_id,
276
298
  notif,
277
299
  )
278
300
  handles.append(kv_xfer_handle)
@@ -280,13 +302,15 @@ class NixlKVManager(CommonKVManager):
280
302
  if is_last:
281
303
  assert aux_index is not None
282
304
  aux_xfer_handle = self.send_aux(
283
- peer_name,
305
+ req.agent_name,
284
306
  aux_index,
285
- req.dst_aux_ptrs,
307
+ self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
286
308
  req.dst_aux_index,
287
309
  str(req.room) + "_aux",
288
310
  )
289
311
  handles.append(aux_xfer_handle)
312
+ if is_last:
313
+ del self.transfer_infos[bootstrap_room]
290
314
  return handles
291
315
 
292
316
  def update_transfer_status(self):
@@ -328,16 +352,23 @@ class NixlKVManager(CommonKVManager):
328
352
  ), f"First message should be {GUARD}. Foreign traffic?"
329
353
  waiting_req_bytes = waiting_req_bytes[1:]
330
354
  room = waiting_req_bytes[0].decode("ascii")
331
-
332
- required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
355
+ agent_name = waiting_req_bytes[3].decode("ascii")
356
+ if room == "None":
357
+ # Register new peer and save KV base pointers.
358
+ self._add_remote_peer(
359
+ KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
360
+ )
361
+ logger.debug(f"Register KVArgs from {agent_name} successfully")
362
+ continue
333
363
  room = int(room)
334
- agent_name = waiting_req_bytes[4].decode("ascii")
335
364
  if room not in self.transfer_infos:
336
365
  self.transfer_infos[room] = {}
337
366
  self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
338
367
  waiting_req_bytes
339
368
  )
340
-
369
+ required_dst_info_num = self.transfer_infos[room][
370
+ agent_name
371
+ ].required_dst_info_num
341
372
  logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
342
373
  if len(self.transfer_infos[room]) == required_dst_info_num:
343
374
  logger.debug(f"{room=} is bootstrapped")
@@ -391,6 +422,7 @@ class NixlKVSender(BaseKVSender):
391
422
  self.chunk_id += 1
392
423
  if is_last:
393
424
  self.has_sent = True
425
+ del self.kv_mgr.request_status[self.bootstrap_room]
394
426
 
395
427
  def poll(self) -> KVPoll:
396
428
  if not self.has_sent:
@@ -415,6 +447,7 @@ class NixlKVReceiver(CommonKVReceiver):
415
447
  data_parallel_rank: Optional[int] = None,
416
448
  ):
417
449
  self.started_transfer = False
450
+ self.conclude_state = None
418
451
  super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
419
452
 
420
453
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
@@ -426,17 +459,8 @@ class NixlKVReceiver(CommonKVReceiver):
426
459
  f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
427
460
  )
428
461
  is_dummy = bootstrap_info["is_dummy"]
429
-
430
- # TODO: send_kv_args earlier
431
- packed_kv_data_ptrs = b"".join(
432
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
433
- )
434
- packed_aux_data_ptrs = b"".join(
435
- struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
436
- )
437
-
438
462
  logger.debug(
439
- f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
463
+ f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
440
464
  )
441
465
  sock, lock = self._connect("tcp://" + self.prefill_server_url)
442
466
  with lock:
@@ -446,13 +470,9 @@ class NixlKVReceiver(CommonKVReceiver):
446
470
  str(self.bootstrap_room).encode("ascii"),
447
471
  get_local_ip_by_remote().encode("ascii"),
448
472
  str(self.kv_mgr.rank_port).encode("ascii"),
449
- self.kv_mgr.agent.get_agent_metadata(),
450
473
  self.kv_mgr.agent.name.encode("ascii"),
451
- packed_kv_data_ptrs,
452
474
  kv_indices.tobytes() if not is_dummy else b"",
453
- packed_aux_data_ptrs,
454
475
  str(aux_index).encode("ascii"),
455
- str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
456
476
  str(self.required_dst_info_num).encode("ascii"),
457
477
  ]
458
478
  )
@@ -460,17 +480,45 @@ class NixlKVReceiver(CommonKVReceiver):
460
480
  self.started_transfer = True
461
481
 
462
482
  def poll(self) -> KVPoll:
483
+ if self.conclude_state is not None:
484
+ return self.conclude_state
463
485
  if not self.started_transfer:
464
486
  return KVPoll.WaitingForInput # type: ignore
465
487
 
466
488
  self.kv_mgr.update_transfer_status()
467
-
468
489
  if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
490
+ self.conclude_state = KVPoll.Success
491
+ del self.kv_mgr.transfer_statuses[self.bootstrap_room]
469
492
  return KVPoll.Success # type: ignore
470
493
  return KVPoll.WaitingForInput # type: ignore
471
494
 
472
495
  def _register_kv_args(self):
473
- pass
496
+ for bootstrap_info in self.bootstrap_infos:
497
+ self.prefill_server_url = (
498
+ f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
499
+ )
500
+ packed_kv_data_ptrs = b"".join(
501
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
502
+ )
503
+ packed_aux_data_ptrs = b"".join(
504
+ struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
505
+ )
506
+
507
+ sock, lock = self._connect("tcp://" + self.prefill_server_url)
508
+ with lock:
509
+ sock.send_multipart(
510
+ [
511
+ GUARD,
512
+ "None".encode("ascii"),
513
+ get_local_ip_by_remote().encode("ascii"),
514
+ str(self.kv_mgr.rank_port).encode("ascii"),
515
+ self.kv_mgr.agent.name.encode("ascii"),
516
+ self.kv_mgr.agent.get_agent_metadata(),
517
+ packed_kv_data_ptrs,
518
+ packed_aux_data_ptrs,
519
+ str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
520
+ ]
521
+ )
474
522
 
475
523
  def failure_exception(self):
476
524
  raise Exception("Fake KVReceiver Exception")
@@ -93,8 +93,6 @@ class PrefillBootstrapQueue:
93
93
  self.gpu_id = gpu_id
94
94
  self.bootstrap_port = bootstrap_port
95
95
  self.queue: List[Req] = []
96
- self.pp_rank = pp_rank
97
- self.pp_size = pp_size
98
96
  self.gloo_group = gloo_group
99
97
  self.max_total_num_tokens = max_total_num_tokens
100
98
  self.scheduler = scheduler
@@ -124,6 +122,9 @@ class PrefillBootstrapQueue:
124
122
  kv_args.kv_data_ptrs = kv_data_ptrs
125
123
  kv_args.kv_data_lens = kv_data_lens
126
124
  kv_args.kv_item_lens = kv_item_lens
125
+ if not self.is_mla_backend:
126
+ kv_args.kv_head_num = self.token_to_kv_pool.head_num
127
+ kv_args.page_size = self.token_to_kv_pool.page_size
127
128
 
128
129
  kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
129
130
  self.metadata_buffers.get_buf_infos()
@@ -107,9 +107,6 @@ class MetadataBuffers:
107
107
  # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
108
108
  self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
109
109
 
110
- self.output_hidden_states = torch.zeros(
111
- (size, hidden_size), dtype=dtype, device=device
112
- )
113
110
  self.output_token_logprobs_val = torch.zeros(
114
111
  (size, 16), dtype=torch.float32, device=device
115
112
  )
@@ -122,51 +119,50 @@ class MetadataBuffers:
122
119
  self.output_top_logprobs_idx = torch.zeros(
123
120
  (size, max_top_logprobs_num), dtype=torch.int32, device=device
124
121
  )
122
+ self.output_hidden_states = torch.zeros(
123
+ (size, hidden_size), dtype=dtype, device=device
124
+ )
125
125
 
126
126
  def get_buf_infos(self):
127
127
  ptrs = [
128
128
  self.output_ids.data_ptr(),
129
- self.output_hidden_states.data_ptr(), # TODO: set None to avoid transfer hidden_states when spec_algorithm is None
130
129
  self.output_token_logprobs_val.data_ptr(),
131
130
  self.output_token_logprobs_idx.data_ptr(),
132
131
  self.output_top_logprobs_val.data_ptr(),
133
132
  self.output_top_logprobs_idx.data_ptr(),
133
+ self.output_hidden_states.data_ptr(),
134
134
  ]
135
135
  data_lens = [
136
136
  self.output_ids.nbytes,
137
- self.output_hidden_states.nbytes,
138
137
  self.output_token_logprobs_val.nbytes,
139
138
  self.output_token_logprobs_idx.nbytes,
140
139
  self.output_top_logprobs_val.nbytes,
141
140
  self.output_top_logprobs_idx.nbytes,
141
+ self.output_hidden_states.nbytes,
142
142
  ]
143
143
  item_lens = [
144
144
  self.output_ids[0].nbytes,
145
- self.output_hidden_states[0].nbytes,
146
145
  self.output_token_logprobs_val[0].nbytes,
147
146
  self.output_token_logprobs_idx[0].nbytes,
148
147
  self.output_top_logprobs_val[0].nbytes,
149
148
  self.output_top_logprobs_idx[0].nbytes,
149
+ self.output_hidden_states[0].nbytes,
150
150
  ]
151
151
  return ptrs, data_lens, item_lens
152
152
 
153
153
  def get_buf(self, idx: int):
154
154
  return (
155
155
  self.output_ids[idx],
156
- self.output_hidden_states[idx],
157
156
  self.output_token_logprobs_val[idx],
158
157
  self.output_token_logprobs_idx[idx],
159
158
  self.output_top_logprobs_val[idx],
160
159
  self.output_top_logprobs_idx[idx],
160
+ self.output_hidden_states[idx],
161
161
  )
162
162
 
163
163
  def set_buf(self, req: Req):
164
164
 
165
165
  self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
166
- if req.hidden_states_tensor is not None:
167
- self.output_hidden_states[req.metadata_buffer_index].copy_(
168
- req.hidden_states_tensor
169
- )
170
166
  if req.return_logprob:
171
167
  if req.output_token_logprobs_val: # not none or empty list
172
168
  self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -189,6 +185,11 @@ class MetadataBuffers:
189
185
  ] = torch.tensor(
190
186
  req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
191
187
  )
188
+ # for PD + spec decode
189
+ if req.hidden_states_tensor is not None:
190
+ self.output_hidden_states[req.metadata_buffer_index].copy_(
191
+ req.hidden_states_tensor
192
+ )
192
193
 
193
194
 
194
195
  #########################
@@ -115,13 +115,13 @@ class Engine(EngineBase):
115
115
  atexit.register(self.shutdown)
116
116
 
117
117
  # Allocate ports for inter-process communications
118
- port_args = PortArgs.init_new(server_args)
118
+ self.port_args = PortArgs.init_new(server_args)
119
119
  logger.info(f"{server_args=}")
120
120
 
121
121
  # Launch subprocesses
122
122
  tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
123
123
  server_args=server_args,
124
- port_args=port_args,
124
+ port_args=self.port_args,
125
125
  )
126
126
  self.server_args = server_args
127
127
  self.tokenizer_manager = tokenizer_manager
@@ -130,7 +130,7 @@ class Engine(EngineBase):
130
130
 
131
131
  context = zmq.Context(2)
132
132
  self.send_to_rpc = get_zmq_socket(
133
- context, zmq.DEALER, port_args.rpc_ipc_name, True
133
+ context, zmq.DEALER, self.port_args.rpc_ipc_name, True
134
134
  )
135
135
 
136
136
  def generate(
@@ -242,6 +242,7 @@ class Engine(EngineBase):
242
242
  token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
243
243
  lora_path: Optional[List[Optional[str]]] = None,
244
244
  custom_logit_processor: Optional[Union[List[str], str]] = None,
245
+ return_hidden_states: bool = False,
245
246
  stream: bool = False,
246
247
  bootstrap_host: Optional[Union[List[str], str]] = None,
247
248
  bootstrap_port: Optional[Union[List[int], int]] = None,
@@ -274,6 +275,7 @@ class Engine(EngineBase):
274
275
  top_logprobs_num=top_logprobs_num,
275
276
  token_ids_logprob=token_ids_logprob,
276
277
  lora_path=lora_path,
278
+ return_hidden_states=return_hidden_states,
277
279
  stream=stream,
278
280
  custom_logit_processor=custom_logit_processor,
279
281
  bootstrap_host=bootstrap_host,
@@ -14,7 +14,8 @@
14
14
  """Pydantic models for OpenAI API protocol"""
15
15
 
16
16
  import time
17
- from typing import Dict, List, Optional, Union
17
+ from dataclasses import dataclass
18
+ from typing import Any, Dict, List, Optional, Union
18
19
 
19
20
  from pydantic import (
20
21
  BaseModel,
@@ -195,6 +196,9 @@ class CompletionRequest(BaseModel):
195
196
  bootstrap_port: Optional[int] = None
196
197
  bootstrap_room: Optional[int] = None
197
198
 
199
+ # For request id
200
+ rid: Optional[Union[List[str], str]] = None
201
+
198
202
  @field_validator("max_tokens")
199
203
  @classmethod
200
204
  def validate_max_tokens_positive(cls, v):
@@ -309,6 +313,18 @@ class ChatCompletionMessageGenericParam(BaseModel):
309
313
  reasoning_content: Optional[str] = None
310
314
  tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
311
315
 
316
+ @field_validator("role", mode="before")
317
+ @classmethod
318
+ def _normalize_role(cls, v):
319
+ if isinstance(v, str):
320
+ v_lower = v.lower()
321
+ if v_lower not in {"system", "assistant", "tool"}:
322
+ raise ValueError(
323
+ "'role' must be one of 'system', 'assistant', or 'tool' (case-insensitive)."
324
+ )
325
+ return v_lower
326
+ raise ValueError("'role' must be a string")
327
+
312
328
 
313
329
  class ChatCompletionMessageUserParam(BaseModel):
314
330
  role: Literal["user"]
@@ -429,8 +445,8 @@ class ChatCompletionRequest(BaseModel):
429
445
  stream_reasoning: bool = True
430
446
  chat_template_kwargs: Optional[Dict] = None
431
447
 
432
- # The request id.
433
- rid: Optional[str] = None
448
+ # For request id
449
+ rid: Optional[Union[List[str], str]] = None
434
450
 
435
451
  # For PD disaggregation
436
452
  bootstrap_host: Optional[str] = None
@@ -528,7 +544,7 @@ class EmbeddingRequest(BaseModel):
528
544
  user: Optional[str] = None
529
545
 
530
546
  # The request id.
531
- rid: Optional[str] = None
547
+ rid: Optional[Union[List[str], str]] = None
532
548
 
533
549
 
534
550
  class EmbeddingObject(BaseModel):
@@ -587,3 +603,30 @@ OpenAIServingRequest = Union[
587
603
  ScoringRequest,
588
604
  V1RerankReqInput,
589
605
  ]
606
+
607
+
608
+ @dataclass
609
+ class MessageProcessingResult:
610
+ """Result of processing chat messages and applying templates.
611
+
612
+ This dataclass encapsulates all the outputs from message processing including
613
+ prompt generation, multimodal data extraction, and constraint preparation.
614
+ Used internally by OpenAIServingChat to pass processed data between methods.
615
+
616
+ Args:
617
+ prompt: The final text prompt after applying chat template
618
+ prompt_ids: Either the text prompt (str) or tokenized IDs (List[int])
619
+ image_data: Extracted image data from messages, if any
620
+ audio_data: Extracted audio data from messages, if any
621
+ modalities: List of modality types present in the messages
622
+ stop: Combined stop strings from template and request
623
+ tool_call_constraint: Optional constraint for structured tool calls
624
+ """
625
+
626
+ prompt: str
627
+ prompt_ids: Union[str, List[int]]
628
+ image_data: Optional[Any]
629
+ audio_data: Optional[Any]
630
+ modalities: List[str]
631
+ stop: List[str]
632
+ tool_call_constraint: Optional[Any] = None