sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py +314 -103

@@ -83,6 +83,9 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
+    LoRAUpdateResult,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -99,6 +102,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
@@ -106,11 +111,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
-from sglang.srt.managers.multimodal_processor import (
-    get_dummy_processor,
-    get_mm_processor,
-    import_processors,
-)
+from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -145,7 +146,9 @@ class ReqState:
 
     # For streaming output
     last_output_offset: int = 0
+
     # For incremental state update.
+    # TODO(lianmin): do not initialize some lists if not needed.
     text: str = ""
     output_ids: List[int] = dataclasses.field(default_factory=list)
     input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
@@ -180,6 +183,8 @@ class TokenizerManager:
             if server_args.preferred_sampling_params
             else None
         )
+        self.crash_dump_folder = server_args.crash_dump_folder
+        self.crash_dump_performed = False  # Flag to ensure dump is only called once
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -194,11 +199,12 @@
         self.model_path = server_args.model_path
         self.served_model_name = server_args.served_model_name
         self.model_config = ModelConfig.from_server_args(server_args)
-
         self.is_generation = self.model_config.is_generation
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
+        self._updating = False
+        self._cond = asyncio.Condition()
 
         if self.model_config.is_multimodal:
             import_processors()
@@ -236,6 +242,12 @@
             revision=server_args.revision,
         )
 
+        # Initialize loaded loRA adapters with the initial lora paths in the server_args.
+        # This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
+        self.loaded_lora_adapters: Dict[str, str] = dict(
+            self.server_args.lora_paths or {}
+        )
+
         # Store states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
@@ -245,20 +257,38 @@
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
+        self.crash_dump_request_list: deque[Tuple] = deque()
         self.log_request_metadata = self.get_log_request_metadata()
+        self.session_futures = {}  # session_id -> asyncio event
+        self.max_req_input_len = None
+        self.asyncio_tasks = set()
 
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self.asyncio_tasks = set()
 
-        # For session info
-        self.session_futures = {}  # session_id -> asyncio event
+        # For pd disaggregtion
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.disaggregation_transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Start kv boostrap server on prefill
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # only start bootstrap server on prefill tm
+            kv_bootstrap_server_class = get_kv_class(
+                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
+            )
+            self.bootstrap_server = kv_bootstrap_server_class(
+                self.server_args.disaggregation_bootstrap_port
+            )
 
-        # Set after scheduler is initialized
-        self.max_req_input_len = None
+        # For load balancing
+        self.current_load = 0
+        self.current_load_lock = asyncio.Lock()
 
         # Metrics
         if self.enable_metrics:
@@ -301,7 +331,6 @@
         self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -311,6 +340,9 @@
         self.expert_distribution_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.update_lora_adapter_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
 
         self._result_dispatcher = TypeBasedDispatcher(
             [
@@ -377,38 +409,25 @@
                     ExpertDistributionReqOutput,
                     self.expert_distribution_communicator.handle_recv,
                 ),
+                (
+                    LoRAUpdateResult,
+                    self.update_lora_adapter_communicator.handle_recv,
+                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
 
-        # For pd disaggregtion
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-
-        self.current_load = 0
-        self.current_load_lock = asyncio.Lock()
-
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
+        async with self._cond:
+            await self._cond.wait_for(lambda: not self._updating)
 
         self.auto_create_handle_loop()
+        obj.normalize_batch_and_arguments()
 
         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -416,22 +435,6 @@
                 "Please add `--is-embedding` when launching the server or try another model."
             )
 
-        obj.normalize_batch_and_arguments()
-
-        if isinstance(obj, GenerateReqInput):
-            return_hidden_states = obj.return_hidden_states
-            has_return_hidden_states = return_hidden_states == True or (
-                isinstance(return_hidden_states, list) and any(return_hidden_states)
-            )
-            if (
-                not self.server_args.enable_return_hidden_states
-                and has_return_hidden_states
-            ):
-                raise ValueError(
-                    "return_hidden_states=True requires the server to be started "
-                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
-                )
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
@@ -439,8 +442,7 @@
             )
 
         async with self.model_update_lock.reader_lock:
-            is_single = obj.is_single
-            if is_single:
+            if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
                 async for response in self._wait_one_response(obj, state, request):
@@ -491,23 +493,28 @@
         token_type_ids = encoded.get("token_type_ids", [None])[0]
 
         if self.mm_processor and obj.contains_mm_input():
-            image_inputs = await self.mm_processor.process_mm_data_async(
+            if not isinstance(obj.image_data, list):
+                obj.image_data = [obj.image_data]
+            if not isinstance(obj.audio_data, list):
+                obj.audio_data = [obj.audio_data]
+            mm_inputs: Dict = await self.mm_processor.process_mm_data_async(
                 image_data=obj.image_data,
+                audio_data=obj.audio_data,
                 input_text=input_text or input_ids,
                 request_obj=obj,
                 max_req_input_len=self.max_req_input_len,
             )
-            if image_inputs and "input_ids" in image_inputs:
-                input_ids = image_inputs["input_ids"]
+            if mm_inputs and "input_ids" in mm_inputs:
+                input_ids = mm_inputs["input_ids"]
         else:
-            image_inputs: Optional[Dict] = None
+            mm_inputs = None
 
-        self._validate_token_len(obj, input_ids)
+        self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
-            obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
+            obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
 
-    def _validate_token_len(
+    def _validate_one_request(
         self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
     ) -> None:
         """Validates that the input token count and the requested token count doesn't exceed the model's context length."""
@@ -536,25 +543,15 @@
             )
             raise ValueError(error_msg)
 
-    def _create_tokenized_object(
-        self,
-        obj: Union[GenerateReqInput, EmbeddingReqInput],
-        input_text: str,
-        input_ids: List[int],
-        input_embeds: Optional[Union[List[float], None]] = None,
-        image_inputs: Optional[Dict] = None,
-        token_type_ids: Optional[List[int]] = None,
-    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
-        """Create a tokenized request object from common parameters."""
-
-        if self.is_generation:
-            return_logprob = obj.return_logprob
-            logprob_start_len = obj.logprob_start_len
-            top_logprobs_num = obj.top_logprobs_num
-            token_ids_logprob = obj.token_ids_logprob
-            session_params = (
-                SessionParams(**obj.session_params) if obj.session_params else None
-            )
+        if isinstance(obj, GenerateReqInput):
+            if (
+                obj.return_hidden_states
+                and not self.server_args.enable_return_hidden_states
+            ):
+                raise ValueError(
+                    "The server is not configured to return the hidden states. "
+                    "Please set `--enable-return-hidden-states` to enable this feature."
+                )
             if (
                 obj.custom_logit_processor
                 and not self.server_args.enable_custom_logit_processor
@@ -563,7 +560,27 @@
                     "The server is not configured to enable custom logit processor. "
                     "Please set `--enable-custom-logits-processor` to enable this feature."
                 )
+            if self.server_args.lora_paths and obj.lora_path:
+                self._validate_lora_adapters(obj)
+
+    def _validate_input_ids_in_vocab(
+        self, input_ids: List[int], vocab_size: int
+    ) -> None:
+        if any(id >= vocab_size for id in input_ids):
+            raise ValueError(
+                f"The input_ids {input_ids} contains values greater than the vocab size ({vocab_size})."
+            )
+
+    def _create_tokenized_object(
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        input_text: str,
+        input_ids: List[int],
+        input_embeds: Optional[Union[List[float], None]] = None,
+        mm_inputs: Optional[Dict] = None,
+        token_type_ids: Optional[List[int]] = None,
+    ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
+        """Create a tokenized request object from common parameters."""
         # Parse sampling parameters
         # Note: if there are preferred sampling params, we use them if they are not
         # explicitly passed in sampling_params
@@ -577,16 +594,20 @@
 
         # Build return object
         if isinstance(obj, GenerateReqInput):
+            session_params = (
+                SessionParams(**obj.session_params) if obj.session_params else None
+            )
+
             tokenized_obj = TokenizedGenerateReqInput(
                 obj.rid,
                 input_text,
                 input_ids,
-                image_inputs,
+                mm_inputs,
                 sampling_params,
-                return_logprob,
-                logprob_start_len,
-                top_logprobs_num,
-                token_ids_logprob,
+                obj.return_logprob,
+                obj.logprob_start_len,
+                obj.top_logprobs_num,
+                obj.token_ids_logprob,
                 obj.stream,
                 bootstrap_host=obj.bootstrap_host,
                 bootstrap_port=obj.bootstrap_port,
@@ -603,7 +624,7 @@
                 obj.rid,
                 input_text,
                 input_ids,
-                image_inputs,
+                mm_inputs,
                 token_type_ids,
                 sampling_params,
             )
@@ -641,9 +662,9 @@
     ) -> None:
        """Validate constraints for batch tokenization processing."""
        for i in range(batch_size):
-            if self.is_generation and obj[i].image_data:
+            if self.is_generation and obj[i].contains_mm_input():
                 raise ValueError(
-                    "For image input processing do not set `enable_tokenizer_batch_encode`."
+                    "For multimodal input processing do not set `enable_tokenizer_batch_encode`."
                 )
             if obj[i].input_ids is not None:
                 raise ValueError(
@@ -654,6 +675,21 @@
                     "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
                 )
 
+    def _validate_lora_adapters(self, obj: GenerateReqInput):
+        """Validate that the requested LoRA adapters are loaded."""
+        requested_adapters = (
+            set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
+        )
+        loaded_adapters = (
+            self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
+        )
+        unloaded_adapters = requested_adapters - loaded_adapters
+        if unloaded_adapters:
+            raise ValueError(
+                f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
+                f"Loaded adapters: {loaded_adapters}."
+            )
+
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
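
Note: _validate_lora_adapters is plain set arithmetic — the adapters named by the request must be a subset of loaded_lora_adapters. A standalone sketch of the same check (function and variable names are illustrative, not from the package):

    from typing import Dict, List, Union

    def validate_lora_request(
        lora_path: Union[str, List[str]],
        loaded_lora_adapters: Dict[str, str],
    ) -> None:
        # Normalize a single adapter name into a set, then diff against
        # the loaded-adapter names, mirroring the method above.
        requested = set(lora_path) if isinstance(lora_path, list) else {lora_path}
        missing = requested - loaded_lora_adapters.keys()
        if missing:
            raise ValueError(f"LoRA adapters not loaded: {missing}")

    validate_lora_request("sql-lora", {"sql-lora": "/adapters/sql"})  # passes
    # validate_lora_request("chat-lora", {"sql-lora": "/adapters/sql"})  # raises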
@@ -815,10 +851,10 @@
     async def flush_cache(self) -> FlushCacheReqOutput:
         return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
 
-    def abort_request(self, rid: str):
-        if rid not in self.rid_to_state:
+    def abort_request(self, rid: str = "", abort_all: bool = False):
+        if not abort_all and rid not in self.rid_to_state:
             return
-        req = AbortReq(rid)
+        req = AbortReq(rid, abort_all)
         self.send_to_scheduler.send_pyobj(req)
 
         if self.enable_metrics:
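
Note: abort_request gains an abort_all flag; when set, the per-rid existence check is skipped and AbortReq(rid, abort_all) tells the scheduler to drop every in-flight request. It is invoked below by pause_generation and by the weight-update paths when obj.abort_all_requests is set. A hedged usage sketch (tm stands for an already-constructed TokenizerManager):

    tm.abort_request(rid="request-123")  # prior behavior: abort one request
    tm.abort_request(abort_all=True)     # new: abort everything; rid is ignored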
@@ -871,6 +907,16 @@
         self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
 
+    async def pause_generation(self):
+        async with self._cond:
+            self._updating = True
+            self.abort_request(abort_all=True)
+
+    async def continue_generation(self):
+        async with self._cond:
+            self._updating = False
+            self._cond.notify_all()
+
     async def update_weights_from_disk(
         self,
         obj: UpdateWeightFromDiskReqInput,
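
Note: pause_generation/continue_generation close the loop with the _updating/_cond fields added in __init__ and the wait_for at the top of generate_request: while _updating is true, new requests park on the condition; notify_all releases them. A self-contained sketch of the same asyncio pattern, independent of sglang:

    import asyncio

    class PausableGate:
        def __init__(self):
            self._updating = False
            self._cond = asyncio.Condition()

        async def pause(self):
            async with self._cond:
                self._updating = True  # new work now blocks in submit()

        async def resume(self):
            async with self._cond:
                self._updating = False
                self._cond.notify_all()  # wake coroutines parked in wait_for

        async def submit(self, name: str):
            async with self._cond:
                await self._cond.wait_for(lambda: not self._updating)
            print(f"{name} admitted")

    async def main():
        gate = PausableGate()
        await gate.pause()
        task = asyncio.create_task(gate.submit("req-1"))  # blocks while paused
        await asyncio.sleep(0.1)
        await gate.resume()  # releases req-1
        await task

    asyncio.run(main())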
@@ -883,6 +929,9 @@
             obj.load_format = self.server_args.load_format
         logger.info("Start update_weights. Load format=%s", obj.load_format)
 
+        if obj.abort_all_requests:
+            self.abort_request(abort_all=True)
+
         if True:  # Keep this redundant check to simplify some internal code sync
             # Hold the lock if it is not async. This means that weight sync
             # cannot run while requests are in progress.
@@ -938,6 +987,9 @@
             self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
         ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
 
+        if obj.abort_all_requests:
+            self.abort_request(abort_all=True)
+
         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
@@ -954,12 +1006,60 @@
             self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
         ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
 
+        if obj.abort_all_requests:
+            self.abort_request(abort_all=True)
+
         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
             result = (await self.update_weights_from_tensor_communicator(obj))[0]
             return result.success, result.message
 
+    async def load_lora_adapter(
+        self,
+        obj: LoadLoRAAdapterReqInput,
+        _: Optional[fastapi.Request] = None,
+    ) -> LoadLoRAAdapterReqOutput:
+        self.auto_create_handle_loop()
+
+        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+        # with dp_size > 1.
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for dynamic lora loading"
+        logger.info(
+            "Start load Lora adapter. Lora name=%s, path=%s",
+            obj.lora_name,
+            obj.lora_path,
+        )
+
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_lora_adapter_communicator(obj))[0]
+            self.loaded_lora_adapters = result.loaded_adapters
+            return result
+
+    async def unload_lora_adapter(
+        self,
+        obj: UnloadLoRAAdapterReqInput,
+        _: Optional[fastapi.Request] = None,
+    ) -> UnloadLoRAAdapterReqOutput:
+        self.auto_create_handle_loop()
+
+        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+        # with dp_size > 1.
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for dynamic lora loading"
+        logger.info(
+            "Start unload Lora adapter. Lora name=%s",
+            obj.lora_name,
+        )
+
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_lora_adapter_communicator(obj))[0]
+            self.loaded_lora_adapters = result.loaded_adapters
+            return result
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
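
Note: together with the LoadLoRAAdapterReqInput/UnloadLoRAAdapterReqInput messages and the update_lora_adapter_communicator, these methods give the tokenizer manager a dynamic LoRA lifecycle guarded by the weight-update writer lock. A hedged sketch of driving it directly — the lora_name/lora_path fields come from the log calls above; the surrounding setup is assumed, not shown in this diff:

    from sglang.srt.managers.io_struct import (
        LoadLoRAAdapterReqInput,
        UnloadLoRAAdapterReqInput,
    )

    async def rotate_adapter(tm, name: str, path: str):
        # tm: an initialized TokenizerManager with LoRA enabled, dp_size == 1.
        result = await tm.load_lora_adapter(
            LoadLoRAAdapterReqInput(lora_name=name, lora_path=path)
        )
        print("loaded adapters:", result.loaded_adapters)

        # ... serve requests that set lora_path=name ...

        await tm.unload_lora_adapter(UnloadLoRAAdapterReqInput(lora_name=name))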
@@ -1056,12 +1156,38 @@
                     "image_data",
                     "audio_data",
                     "lora_path",
+                    "sampling_params",
+                ]
+            )
+            out_skip_names = set(
+                [
+                    "text",
+                    "output_ids",
+                    "embedding",
                 ]
             )
-            out_skip_names = set(["text", "output_ids", "embedding"])
         elif self.log_requests_level == 1:
-            max_length = 2048
+            max_length = 1 << 30
+            skip_names = set(
+                [
+                    "text",
+                    "input_ids",
+                    "input_embeds",
+                    "image_data",
+                    "audio_data",
+                    "lora_path",
+                ]
+            )
+            out_skip_names = set(
+                [
+                    "text",
+                    "output_ids",
+                    "embedding",
+                ]
+            )
         elif self.log_requests_level == 2:
+            max_length = 2048
+        elif self.log_requests_level == 3:
             max_length = 1 << 30
         else:
             raise ValueError(
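
Note: the logging tiers are renumbered — the old level 1 (truncate to 2048 characters) becomes level 2, and a new level 1 keeps the level-0 field suppression while raising max_length. A reading aid derived only from the branches visible above (level 0's max_length is set before this hunk's context and is not shown):

    # Summary of the get_log_request_metadata branches after this change;
    # an annotation for readers, not the actual implementation.
    LOG_REQUEST_LEVELS = {
        0: "suppress long request/response fields, including sampling_params",
        1: "suppress the same fields; max_length = 1 << 30",
        2: "log all fields, truncated at max_length = 2048",
        3: "log all fields, max_length = 1 << 30",
    }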
@@ -1078,6 +1204,8 @@
             self.dump_requests_folder = obj.dump_requests_folder
         if obj.dump_requests_threshold is not None:
             self.dump_requests_threshold = obj.dump_requests_threshold
+        if obj.crash_dump_folder is not None:
+            self.crash_dump_folder = obj.crash_dump_folder
         logging.info(f"Config logging: {obj=}")
         self.log_request_metadata = self.get_log_request_metadata()
 
@@ -1126,6 +1254,52 @@
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
 
+    def dump_requests_before_crash(self):
+        if self.crash_dump_performed:
+            logger.info(
+                "SIGTERM/SIGQUIT/Exception triggered, but crash dump already performed, skipping."
+            )
+            return
+        logger.error(f"Dumping requests before crash. {self.crash_dump_folder=}")
+        self.crash_dump_performed = True
+        if not self.crash_dump_folder:
+            return
+
+        data_to_dump = []
+        if self.crash_dump_request_list:
+            data_to_dump.extend(self.crash_dump_request_list)
+
+        # Add unfinished requests from rid_to_state
+        unfinished_requests = []
+        for rid, state in self.rid_to_state.items():
+            if not state.finished:
+                unfinished_requests.append(
+                    (state.obj, {}, state.created_time, time.time())
+                )
+        if unfinished_requests:
+            data_to_dump.extend(unfinished_requests)
+
+        if not data_to_dump:
+            return
+
+        filename = os.path.join(
+            self.crash_dump_folder,
+            os.getenv("HOSTNAME", None),
+            f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
+        )
+
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        # Include server_args in the dump
+        data_to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": data_to_dump,
+        }
+        with open(filename, "wb") as f:
+            pickle.dump(data_to_dump_with_server_args, f)
+        logger.error(
+            f"Dumped {len(self.crash_dump_request_list)} finished and {len(unfinished_requests)} unfinished requests before crash to {filename}"
+        )
+
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
             await asyncio.sleep(5)
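
Note: dump_requests_before_crash writes one pickle shaped as {"server_args": ..., "requests": [...]}, where each request entry is a tuple (obj, out_dict, created_time, finished_time) and unfinished requests carry an empty out_dict. Also note the dump path embeds os.getenv("HOSTNAME", None), so os.path.join will fail if HOSTNAME is unset. A hedged sketch for inspecting a dump offline (the path is an example; unpickling needs sglang importable, since server_args is a ServerArgs object):

    import pickle

    # Load a crash dump written by dump_requests_before_crash.
    # The layout follows the dict built in the diff above.
    with open("/dumps/myhost/crash_dump_2025-01-01_00-00-00.pkl", "rb") as f:
        dump = pickle.load(f)

    print(dump["server_args"])
    for obj, out_dict, created_time, finished_time in dump["requests"]:
        print(type(obj).__name__, finished_time - created_time, bool(out_dict))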
@@ -1135,11 +1309,12 @@
             remain_num_req = len(self.rid_to_state)
 
             if self.health_check_failed:
-                # if health check failed, exit immediately
+                # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                     remain_num_req,
                 )
+                self.dump_requests_before_crash()
                 break
 
             elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
@@ -1156,6 +1331,7 @@
             if remain_num_req > 0:
                 await asyncio.sleep(5)
             else:
+                self.dump_requests_before_crash()
                 break
 
         kill_process_tree(os.getpid(), include_parent=True)
@@ -1233,16 +1409,7 @@
                 "meta_info": meta_info,
             }
         elif isinstance(recv_obj, BatchMultimodalOut):
-            if isinstance(recv_obj.outputs[i], str):
-                out_dict = {
-                    "text": recv_obj.outputs[i],
-                    "meta_info": meta_info,
-                }
-            else:
-                out_dict = {
-                    "outputs": json.dumps(recv_obj.outputs[i]),
-                    "meta_info": meta_info,
-                }
+            raise NotImplementedError("BatchMultimodalOut not implemented")
         else:
             assert isinstance(recv_obj, BatchEmbeddingOut)
             out_dict = {
@@ -1266,6 +1433,8 @@
             self.collect_metrics(state, recv_obj, i)
         if self.dump_requests_folder and state.finished and state.obj.log_metrics:
             self.dump_requests(state, out_dict)
+        if self.crash_dump_folder and state.finished and state.obj.log_metrics:
+            self.record_request_for_crash_dump(state, out_dict)
 
     def convert_logprob_style(
         self,
@@ -1277,6 +1446,9 @@
         recv_obj: BatchStrOut,
         recv_obj_index: int,
     ):
+        if recv_obj.input_token_logprobs_val is None:
+            return
+
         if len(recv_obj.input_token_logprobs_val) > 0:
             state.input_token_logprobs_val.extend(
                 recv_obj.input_token_logprobs_val[recv_obj_index]
@@ -1396,7 +1568,10 @@
             else 0
         )
 
-        if state.first_token_time == 0.0:
+        if (
+            state.first_token_time == 0.0
+            and self.disaggregation_mode != DisaggregationMode.PREFILL
+        ):
             state.first_token_time = state.last_time = time.time()
             state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
@@ -1444,16 +1619,49 @@
         to_dump = self.dump_request_list
         self.dump_request_list = []
 
+        to_dump_with_server_args = {
+            "server_args": self.server_args,
+            "requests": to_dump,
+        }
+
         def background_task():
             os.makedirs(self.dump_requests_folder, exist_ok=True)
             with open(filename, "wb") as f:
-                pickle.dump(to_dump, f)
+                pickle.dump(to_dump_with_server_args, f)
 
         # Schedule the task to run in the background without awaiting it
         asyncio.create_task(asyncio.to_thread(background_task))
 
+    def record_request_for_crash_dump(self, state: ReqState, out_dict: dict):
+        current_time = time.time()
+        self.crash_dump_request_list.append(
+            (state.obj, out_dict, state.created_time, current_time)
+        )
+        # Remove requests older than 5 minutes based on finish time
+        while (
+            self.crash_dump_request_list
+            and current_time - self.crash_dump_request_list[0][3] >= 300
+        ):
+            self.crash_dump_request_list.popleft()
+
     def _handle_abort_req(self, recv_obj):
-        self.rid_to_state.pop(recv_obj.rid, None)
+        state = self.rid_to_state[recv_obj.rid]
+        state.finished = True
+        state.out_list.append(
+            {
+                "text": "",
+                "meta_info": {
+                    "id": recv_obj.rid,
+                    "finish_reason": {
+                        "type": "abort",
+                        "message": "Abort before prefill",
+                    },
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                },
+            }
+        )
+        state.event.set()
 
     def _handle_open_session_req_output(self, recv_obj):
         self.session_futures[recv_obj.session_id].set_result(
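
Note: record_request_for_crash_dump maintains a rolling five-minute window — finished requests append on the right as (obj, out_dict, created_time, finished_time), and the deque is pruned from the left by finish time (index 3 of each tuple). Note also that _handle_abort_req now resolves aborted requests with an "abort" finish_reason instead of silently dropping their state. The pruning pattern in isolation (window length matches the diff; the payload is illustrative):

    import time
    from collections import deque

    WINDOW_SECONDS = 300  # five minutes, matching the diff

    recent: deque = deque()

    def record(payload):
        now = time.time()
        recent.append((payload, now))  # newest entry on the right
        # Drop entries whose finish time fell out of the window.
        while recent and now - recent[0][1] >= WINDOW_SECONDS:
            recent.popleft()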
@@ -1574,6 +1782,8 @@ async def print_exception_wrapper(func):
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"TokenizerManager hit an exception: {traceback}")
+        if hasattr(func, "__self__") and isinstance(func.__self__, TokenizerManager):
+            func.__self__.dump_requests_before_crash()
         kill_process_tree(os.getpid(), include_parent=True)
         sys.exit(1)
 
@@ -1592,6 +1802,7 @@ class SignalHandler:
         logger.error(
             "Received sigquit from a child process. It usually means the child failed."
         )
+        self.tokenizer_manager.dump_requests_before_crash()
         kill_process_tree(os.getpid())
 