sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/session_controller.py

@@ -131,6 +131,7 @@ class Session:
             sampling_params=req.sampling_params,
             lora_path=req.lora_path,
             session_id=self.session_id,
+            custom_logit_processor=req.custom_logit_processor,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
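The new `custom_logit_processor` field threads a user-supplied logit hook from the request down to sampling (see also the new `sglang/srt/sampling/custom_logit_processor.py` in the file list). As a rough illustration of what such a hook does conceptually — the class name and signature below are made up for the example, not the shipped base class:

```python
import torch

# Hypothetical processor: ban a fixed set of token ids by forcing their logits
# to -inf before sampling. The real contract (base class, serialization, and
# per-request parameters) is defined in sglang/srt/sampling/custom_logit_processor.py.
class DisallowTokensLogitProcessor:
    def __init__(self, banned_token_ids):
        self.banned_token_ids = list(banned_token_ids)

    def __call__(self, logits: torch.Tensor) -> torch.Tensor:
        # logits has shape [batch_size, vocab_size]
        logits[:, self.banned_token_ids] = float("-inf")
        return logits
```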
sglang/srt/managers/tokenizer_manager.py

@@ -18,10 +18,14 @@ import copy
 import dataclasses
 import logging
 import os
+import pickle
 import signal
 import sys
+import threading
 import time
 import uuid
+from datetime import datetime
+from http import HTTPStatus
 from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 import fastapi
@@ -43,6 +47,7 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
@@ -53,6 +58,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -71,6 +80,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -103,16 +113,19 @@ class TokenizerManager:
         port_args: PortArgs,
     ):
         # Parse args
+
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
+        self.log_requests = server_args.log_requests
+        self.log_requests_level = 0
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = get_zmq_socket(
-            context, zmq.PULL, port_args.tokenizer_ipc_name
+            context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
         self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
         )
 
         # Read model args
@@ -145,6 +158,7 @@ class TokenizerManager:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
             self.tokenizer = self.processor.tokenizer
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -158,11 +172,15 @@ class TokenizerManager:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
 
         # Store states
-        self.to_create_loop = True
+        self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.dump_requests_folder = ""  # By default do not dump
+        self.dump_requests_threshold = 1000
+        self.dump_request_list: List[Tuple] = []
 
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
@@ -188,6 +206,14 @@ class TokenizerManager:
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.release_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.resume_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        # Set after scheduler is initialized
+        self.max_req_input_len = None
 
         # Metrics
         if self.enable_metrics:
@@ -198,6 +224,44 @@ class TokenizerManager:
                 },
             )
 
+        self._result_dispatcher = TypeBasedDispatcher(
+            [
+                (
+                    (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
+                    self._handle_batch_output,
+                ),
+                (OpenSessionReqOutput, self._handle_open_session_req_output),
+                (
+                    UpdateWeightFromDiskReqOutput,
+                    self._handle_update_weights_from_disk_req_output,
+                ),
+                (
+                    InitWeightsUpdateGroupReqOutput,
+                    self.init_weights_update_group_communicator.handle_recv,
+                ),
+                (
+                    UpdateWeightsFromDistributedReqOutput,
+                    self.update_weights_from_distributed_communicator.handle_recv,
+                ),
+                (
+                    UpdateWeightsFromTensorReqOutput,
+                    self.update_weights_from_tensor_communicator.handle_recv,
+                ),
+                (
+                    GetWeightsByNameReqOutput,
+                    self.get_weights_by_name_communicator.handle_recv,
+                ),
+                (
+                    ReleaseMemoryOccupationReqOutput,
+                    self.release_memory_occupation_communicator.handle_recv,
+                ),
+                (
+                    ResumeMemoryOccupationReqOutput,
+                    self.resume_memory_occupation_communicator.handle_recv,
+                ),
+            ]
+        )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
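`_result_dispatcher` replaces the long `isinstance` chain that previously lived in `handle_loop` (removed in a later hunk). The class itself is imported from `sglang.utils`; the sketch below is only an illustrative reconstruction of a dispatcher with this interface, not the shipped implementation:

```python
from typing import Any, Callable, List, Tuple

# Minimal sketch: route an object to the first handler whose registered
# type (or tuple of types) matches via isinstance.
class TypeBasedDispatcher:
    def __init__(self, mapping: List[Tuple[Any, Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any):
        for types, handler in self._mapping:
            if isinstance(obj, types):
                return handler(obj)
        raise ValueError(f"Invalid object: {obj=}")
```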
@@ -215,8 +279,11 @@ class TokenizerManager:
 
         obj.normalize_batch_and_arguments()
 
-        if self.server_args.log_requests:
-            logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}")
+        if self.log_requests:
+            max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+            logger.info(
+                f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
+            )
 
         async with self.model_update_lock.reader_lock:
             is_single = obj.is_single
@@ -248,15 +315,21 @@ class TokenizerManager:
                 )
             input_embeds = obj.input_embeds
             input_ids = obj.input_ids
-        elif obj.input_ids is None:
-            input_ids = self.tokenizer.encode(input_text)
-        else:
+        elif obj.input_ids is not None:
             input_ids = obj.input_ids
+        else:
+            if self.tokenizer is None:
+                raise ValueError(
+                    "The engine initialized with skip_tokenizer_init=True cannot "
+                    "accept text prompts. Please provide input_ids or re-initialize "
+                    "the engine with skip_tokenizer_init=False."
+                )
+            input_ids = self.tokenizer.encode(input_text)
 
         if self.is_generation:
             # TODO: also support getting embeddings for multimodal models
             image_inputs: Dict = await self.image_processor.process_images_async(
-                obj.image_data, input_text or input_ids, obj
+                obj.image_data, input_text or input_ids, obj, self.max_req_input_len
             )
             if image_inputs and "input_ids" in image_inputs:
                 input_ids = image_inputs["input_ids"]
@@ -267,12 +340,28 @@ class TokenizerManager:
                 SessionParams(**obj.session_params) if obj.session_params else None
             )
 
-        if obj.input_ids is not None and len(input_ids) >= self.context_len:
+        input_token_num = len(input_ids) if input_ids is not None else 0
+        if input_token_num >= self.context_len:
             raise ValueError(
-                f"The input ({len(input_ids)} tokens) is longer than the "
+                f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )
 
+        if (
+            obj.sampling_params.get("max_new_tokens") is not None
+            and obj.sampling_params.get("max_new_tokens") + input_token_num
+            >= self.context_len
+        ):
+            raise ValueError(
+                f"Requested token count exceeds the model's maximum context length "
+                f"of {self.context_len} tokens. You requested a total of "
+                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"tokens: {input_token_num} tokens from the input messages and "
+                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
+                f"completion. Please reduce the number of tokens in the input "
+                f"messages or the completion to fit within the limit."
+            )
+
         # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
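The added check rejects a request up front when the prompt plus the requested completion cannot fit into the context window, instead of letting it fail later in the scheduler. A small worked example of the arithmetic, with made-up numbers:

```python
# Illustrative numbers only.
context_len = 4096        # model's context length
input_token_num = 3900    # tokens in the prompt
max_new_tokens = 300      # tokens requested for the completion

total = input_token_num + max_new_tokens  # 4200
if total >= context_len:
    # 4200 >= 4096, so the request is rejected with the error message above.
    print(f"rejected: requested {total} tokens, context length is {context_len}")
```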
@@ -293,6 +382,7 @@ class TokenizerManager:
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
                 session_params=session_params,
+                custom_logit_processor=obj.custom_logit_processor,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -336,10 +426,21 @@ class TokenizerManager:
 
             state.out_list = []
             if state.finished:
-                if self.server_args.log_requests:
-                    msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}"
+                if self.log_requests:
+                    max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+                    msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
                     logger.info(msg)
                 del self.rid_to_state[obj.rid]
+
+                # Check if this was an abort/error created by scheduler
+                if isinstance(out["meta_info"].get("finish_reason"), dict):
+                    finish_reason = out["meta_info"]["finish_reason"]
+                    if (
+                        finish_reason.get("type") == "abort"
+                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
+                    ):
+                        raise ValueError(finish_reason["message"])
+
                 yield out
                 break
 
@@ -548,6 +649,22 @@ class TokenizerManager:
         else:
             return all_parameters
 
+    async def release_memory_occupation(
+        self,
+        obj: ReleaseMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.release_memory_occupation_communicator(obj)
+
+    async def resume_memory_occupation(
+        self,
+        obj: ResumeMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.resume_memory_occupation_communicator(obj)
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
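These two entry points let a co-located workload (for example an RLHF trainer sharing the GPU) temporarily reclaim the memory held by the engine and hand it back later; the new `sglang/srt/torch_memory_saver_adapter.py` in the file list provides the underlying mechanism. A hedged usage sketch, assuming the offline `Engine` added in `sglang/srt/entrypoints/engine.py` exposes wrappers with the same names — verify against that file before relying on it:

```python
import sglang as sgl

# Assumed wrapper names on the Engine; the model path is just an example.
llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

llm.release_memory_occupation()   # give GPU memory back, e.g. for a training step
# ... run the co-located workload here ...
llm.resume_memory_occupation()    # re-occupy memory before serving again

print(llm.generate("The capital of France is")["text"])
```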
@@ -568,9 +685,19 @@ class TokenizerManager:
     async def close_session(
         self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
     ):
-        assert not self.to_create_loop, "close session should not be the first request"
         await self.send_to_scheduler.send_pyobj(obj)
 
+    def configure_logging(self, obj: ConfigureLoggingReq):
+        if obj.log_requests is not None:
+            self.log_requests = obj.log_requests
+        if obj.log_requests_level is not None:
+            self.log_requests_level = obj.log_requests_level
+        if obj.dump_requests_folder is not None:
+            self.dump_requests_folder = obj.dump_requests_folder
+        if obj.dump_requests_threshold is not None:
+            self.dump_requests_threshold = obj.dump_requests_threshold
+        logging.info(f"Config logging: {obj=}")
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
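`configure_logging` makes request logging and request dumping adjustable at runtime; the file list also adds a small helper, `sglang/srt/managers/configure_logging.py`, that drives it. A sketch assuming the HTTP server added in `sglang/srt/entrypoints/http_server.py` exposes a matching `/configure_logging` route (field names mirror `ConfigureLoggingReq`):

```python
import requests

# Assumed route and port; adjust to your deployment.
requests.post(
    "http://localhost:30000/configure_logging",
    json={
        "log_requests": True,
        "log_requests_level": 1,          # 0 truncates logged objects, 1 logs them in full
        "dump_requests_folder": "/tmp/sglang_request_dumps",
        "dump_requests_threshold": 1000,  # flush a .pkl file every N finished requests
    },
)
```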
@@ -586,22 +713,35 @@ class TokenizerManager:
         return background_tasks
 
     def auto_create_handle_loop(self):
-        if not self.to_create_loop:
+        if self.no_create_loop:
             return
 
-        self.to_create_loop = False
+        self.no_create_loop = True
         loop = asyncio.get_event_loop()
-        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.handle_loop))
+        )
 
-        signal_handler = SignalHandler(self)
-        loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
-        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
+        # We cannot add signal handler when the tokenizer manager is not in
+        # the main thread due to the CPython limitation.
+        if threading.current_thread() is threading.main_thread():
+            signal_handler = SignalHandler(self)
+            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+        else:
+            logger.warning(
+                "Signal handler is not added because the tokenizer manager is "
+                "not in the main thread. This disables graceful shutdown of the "
+                "tokenizer manager when SIGTERM is received."
+            )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
             await asyncio.sleep(5)
 
-        # drain requests
+        # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
             logger.info(
@@ -619,143 +759,64 @@ class TokenizerManager:
         """The event loop that handles requests"""
 
         while True:
-            recv_obj: Union[
-                BatchStrOut,
-                BatchEmbeddingOut,
-                BatchTokenIDOut,
-                UpdateWeightFromDiskReqOutput,
-                UpdateWeightsFromDistributedReqOutput,
-                GetWeightsByNameReqOutput,
-                InitWeightsUpdateGroupReqOutput,
-            ] = await self.recv_from_detokenizer.recv_pyobj()
-
-            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
-                for i, rid in enumerate(recv_obj.rids):
-                    state = self.rid_to_state.get(rid, None)
-                    if state is None:
-                        continue
-
-                    meta_info = {
-                        "id": rid,
-                        "finish_reason": recv_obj.finished_reasons[i],
-                        "prompt_tokens": recv_obj.prompt_tokens[i],
-                    }
+            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+            self._result_dispatcher(recv_obj)
 
-                    if getattr(state.obj, "return_logprob", False):
-                        self.convert_logprob_style(
-                            meta_info,
-                            state.obj.top_logprobs_num,
-                            state.obj.return_text_in_logprobs,
-                            recv_obj,
-                            i,
-                        )
-
-                    if not isinstance(recv_obj, BatchEmbeddingOut):
-                        meta_info.update(
-                            {
-                                "completion_tokens": recv_obj.completion_tokens[i],
-                                "cached_tokens": recv_obj.cached_tokens[i],
-                            }
-                        )
-
-                    if isinstance(recv_obj, BatchStrOut):
-                        out_dict = {
-                            "text": recv_obj.output_strs[i],
-                            "meta_info": meta_info,
-                        }
-                        if self.server_args.return_token_ids:
-                            out_dict.update(
-                                {
-                                    "input_ids": recv_obj.origin_input_ids[i],
-                                    "output_ids": recv_obj.output_ids[i],
-                                }
-                            )
-                    elif isinstance(recv_obj, BatchTokenIDOut):
-                        out_dict = {
-                            "token_ids": recv_obj.output_ids[i],
-                            "meta_info": meta_info,
-                        }
-                    else:
-                        assert isinstance(recv_obj, BatchEmbeddingOut)
-                        out_dict = {
-                            "embedding": recv_obj.embeddings[i],
-                            "meta_info": meta_info,
-                        }
-                    state.out_list.append(out_dict)
-                    state.finished = recv_obj.finished_reasons[i] is not None
-                    state.event.set()
-
-                    if self.enable_metrics:
-                        completion_tokens = (
-                            recv_obj.completion_tokens[i]
-                            if getattr(recv_obj, "completion_tokens", None)
-                            else 0
-                        )
-
-                        if state.first_token_time is None:
-                            state.first_token_time = time.time()
-                            self.metrics_collector.observe_time_to_first_token(
-                                state.first_token_time - state.created_time
-                            )
-                        else:
-                            if completion_tokens >= 2:
-                                # Compute time_per_output_token for the streaming case
-                                self.metrics_collector.observe_time_per_output_token(
-                                    (time.time() - state.first_token_time)
-                                    / (completion_tokens - 1)
-                                )
-
-                        if state.finished:
-                            self.metrics_collector.inc_prompt_tokens(
-                                recv_obj.prompt_tokens[i]
-                            )
-                            self.metrics_collector.inc_generation_tokens(
-                                completion_tokens
-                            )
-                            self.metrics_collector.observe_e2e_request_latency(
-                                time.time() - state.created_time
-                            )
-                            # Compute time_per_output_token for the non-streaming case
-                            if (
-                                hasattr(state.obj, "stream")
-                                and not state.obj.stream
-                                and completion_tokens >= 1
-                            ):
-                                self.metrics_collector.observe_time_per_output_token(
-                                    (time.time() - state.created_time)
-                                    / completion_tokens
-                                )
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id if recv_obj.success else None
+    def _handle_batch_output(
+        self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
+    ):
+        for i, rid in enumerate(recv_obj.rids):
+            state = self.rid_to_state.get(rid, None)
+            if state is None:
+                continue
+
+            meta_info = {
+                "id": rid,
+                "finish_reason": recv_obj.finished_reasons[i],
+                "prompt_tokens": recv_obj.prompt_tokens[i],
+            }
+
+            if getattr(state.obj, "return_logprob", False):
+                self.convert_logprob_style(
+                    meta_info,
+                    state.obj.top_logprobs_num,
+                    state.obj.return_text_in_logprobs,
+                    recv_obj,
+                    i,
+                )
+
+            if not isinstance(recv_obj, BatchEmbeddingOut):
+                meta_info.update(
+                    {
+                        "completion_tokens": recv_obj.completion_tokens[i],
+                        "cached_tokens": recv_obj.cached_tokens[i],
+                    }
                 )
-            elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput):
-                if self.server_args.dp_size == 1:
-                    self.model_update_result.set_result(recv_obj)
-                else:  # self.server_args.dp_size > 1
-                    self.model_update_tmp.append(recv_obj)
-                    # set future if the all results are recevied
-                    if len(self.model_update_tmp) == self.server_args.dp_size:
-                        self.model_update_result.set_result(self.model_update_tmp)
-            elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for init parameter update group"
-                self.init_weights_update_group_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for update weights from distributed"
-                self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for update weights from distributed"
-                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
-                self.get_weights_by_name_communicator.handle_recv(recv_obj)
+
+            if isinstance(recv_obj, BatchStrOut):
+                out_dict = {
+                    "text": recv_obj.output_strs[i],
+                    "meta_info": meta_info,
+                }
+            elif isinstance(recv_obj, BatchTokenIDOut):
+                out_dict = {
+                    "token_ids": recv_obj.output_ids[i],
+                    "meta_info": meta_info,
+                }
             else:
-                raise ValueError(f"Invalid object: {recv_obj=}")
+                assert isinstance(recv_obj, BatchEmbeddingOut)
+                out_dict = {
+                    "embedding": recv_obj.embeddings[i],
+                    "meta_info": meta_info,
+                }
+            state.out_list.append(out_dict)
+            state.finished = recv_obj.finished_reasons[i] is not None
+            state.event.set()
+
+            if self.enable_metrics and state.obj.log_metrics:
+                self.collect_metrics(state, recv_obj, i)
+            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
+                self.dump_requests(state, out_dict)
 
     def convert_logprob_style(
         self,
@@ -775,9 +836,6 @@ class TokenizerManager:
             recv_obj.output_token_logprobs_idx[recv_obj_index],
             return_text_in_logprobs,
         )
-        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
-            recv_obj_index
-        ]
 
         if top_logprobs_num > 0:
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
@@ -827,6 +885,93 @@ class TokenizerManager:
                 ret.append(None)
         return ret
 
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+        completion_tokens = (
+            recv_obj.completion_tokens[i]
+            if getattr(recv_obj, "completion_tokens", None)
+            else 0
+        )
+
+        if state.first_token_time is None:
+            state.first_token_time = time.time()
+            self.metrics_collector.observe_time_to_first_token(
+                state.first_token_time - state.created_time
+            )
+        else:
+            if completion_tokens >= 2:
+                # Compute time_per_output_token for the streaming case
+                self.metrics_collector.observe_time_per_output_token(
+                    (time.time() - state.first_token_time) / (completion_tokens - 1)
+                )
+
+        if state.finished:
+            self.metrics_collector.observe_one_finished_request(
+                recv_obj.prompt_tokens[i], completion_tokens
+            )
+            self.metrics_collector.observe_e2e_request_latency(
+                time.time() - state.created_time
+            )
+            # Compute time_per_output_token for the non-streaming case
+            if (
+                hasattr(state.obj, "stream")
+                and not state.obj.stream
+                and completion_tokens >= 1
+            ):
+                self.metrics_collector.observe_time_per_output_token(
+                    (time.time() - state.created_time) / completion_tokens
+                )
+
+    def dump_requests(self, state: ReqState, out_dict: dict):
+        self.dump_request_list.append(
+            (state.obj, out_dict, state.created_time, time.time())
+        )
+
+        if len(self.dump_request_list) >= self.dump_requests_threshold:
+            filename = os.path.join(
+                self.dump_requests_folder,
+                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
+            )
+            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
+
+            to_dump = self.dump_request_list
+            self.dump_request_list = []
+
+            def background_task():
+                os.makedirs(self.dump_requests_folder, exist_ok=True)
+                with open(filename, "wb") as f:
+                    pickle.dump(to_dump, f)
+
+            # Schedule the task to run in the background without awaiting it
+            asyncio.create_task(asyncio.to_thread(background_task))
+
+    def _handle_open_session_req_output(self, recv_obj):
+        self.session_futures[recv_obj.session_id].set_result(
+            recv_obj.session_id if recv_obj.success else None
+        )
+
+    def _handle_update_weights_from_disk_req_output(self, recv_obj):
+        if self.server_args.dp_size == 1:
+            self.model_update_result.set_result(recv_obj)
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp.append(recv_obj)
+            # set future if the all results are recevied
+            if len(self.model_update_tmp) == self.server_args.dp_size:
+                self.model_update_result.set_result(self.model_update_tmp)
+
+
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print exception.
+    We do another wrapper to handle the exception.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
 
 class SignalHandler:
     def __init__(self, tokenizer_manager):
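`dump_requests` batches finished requests and writes them to timestamped pickle files once the threshold is reached. A minimal sketch for inspecting such a dump offline (the path is an example; unpickling needs sglang importable so the request dataclasses can be reconstructed):

```python
import pickle

# Each file holds a list of (request_obj, out_dict, created_time, finished_time)
# tuples, in the order they were appended by dump_requests().
with open("/tmp/sglang_request_dumps/2025-01-01_12-00-00.pkl", "rb") as f:
    records = pickle.load(f)

for req_obj, out, created, finished in records:
    meta = out["meta_info"]
    print(f"{meta['id']}: {finished - created:.3f}s, "
          f"{meta.get('completion_tokens', 0)} completion tokens")
```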
sglang/srt/managers/tp_worker.py

@@ -83,6 +83,7 @@ class TpModelWorker:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
             self.tokenizer = self.processor.tokenizer
         else:
@@ -90,6 +91,7 @@ class TpModelWorker:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
            )
        self.device = self.model_runner.device
 
@@ -101,6 +103,7 @@ class TpModelWorker:
                 self.max_total_num_tokens // 2
                 if server_args.max_running_requests is None
                 else server_args.max_running_requests
+                // (server_args.dp_size if server_args.enable_dp_attention else 1)
             ),
             self.model_runner.req_to_token_pool.size,
         )
@@ -142,16 +145,15 @@
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group
 
+    def get_attention_tp_cpu_group(self):
+        return self.model_runner.attention_tp_group.cpu_group
+
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
             self.model_runner.token_to_kv_pool,
         )
 
-    def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        self.model_runner.forward(forward_batch)
-
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,