sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +48 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +34 -0
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/nixl/conn.py +6 -6
  10. sglang/srt/disaggregation/prefill.py +2 -2
  11. sglang/srt/disaggregation/utils.py +1 -1
  12. sglang/srt/distributed/parallel_state.py +44 -17
  13. sglang/srt/entrypoints/EngineBase.py +8 -0
  14. sglang/srt/entrypoints/engine.py +40 -6
  15. sglang/srt/entrypoints/http_server.py +111 -24
  16. sglang/srt/entrypoints/openai/protocol.py +4 -2
  17. sglang/srt/eplb/__init__.py +0 -0
  18. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  19. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  20. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  21. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  22. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  24. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  25. sglang/srt/hf_transformers_utils.py +2 -1
  26. sglang/srt/layers/activation.py +2 -2
  27. sglang/srt/layers/amx_utils.py +86 -0
  28. sglang/srt/layers/attention/ascend_backend.py +219 -0
  29. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  30. sglang/srt/layers/attention/tbo_backend.py +37 -9
  31. sglang/srt/layers/communicator.py +18 -2
  32. sglang/srt/layers/dp_attention.py +9 -3
  33. sglang/srt/layers/elementwise.py +76 -12
  34. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  35. sglang/srt/layers/layernorm.py +26 -0
  36. sglang/srt/layers/linear.py +84 -14
  37. sglang/srt/layers/logits_processor.py +4 -4
  38. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +36 -13
  40. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  41. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
  42. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
  43. sglang/srt/layers/moe/router.py +60 -22
  44. sglang/srt/layers/moe/topk.py +10 -28
  45. sglang/srt/layers/parameter.py +67 -7
  46. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  47. sglang/srt/layers/quantization/fp8.py +44 -0
  48. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  49. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  50. sglang/srt/layers/quantization/gptq.py +5 -1
  51. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  52. sglang/srt/layers/quantization/quant_utils.py +166 -0
  53. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  54. sglang/srt/layers/rotary_embedding.py +2 -2
  55. sglang/srt/layers/vocab_parallel_embedding.py +11 -7
  56. sglang/srt/lora/lora.py +4 -5
  57. sglang/srt/lora/lora_manager.py +73 -20
  58. sglang/srt/managers/configure_logging.py +1 -1
  59. sglang/srt/managers/io_struct.py +50 -13
  60. sglang/srt/managers/mm_utils.py +73 -59
  61. sglang/srt/managers/multimodal_processor.py +2 -6
  62. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  63. sglang/srt/managers/schedule_batch.py +77 -84
  64. sglang/srt/managers/scheduler.py +113 -59
  65. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  66. sglang/srt/managers/session_controller.py +12 -3
  67. sglang/srt/managers/tokenizer_manager.py +314 -103
  68. sglang/srt/managers/tp_worker.py +13 -1
  69. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  70. sglang/srt/mem_cache/allocator.py +290 -0
  71. sglang/srt/mem_cache/chunk_cache.py +34 -2
  72. sglang/srt/mem_cache/memory_pool.py +289 -3
  73. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  74. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  75. sglang/srt/model_executor/forward_batch_info.py +17 -4
  76. sglang/srt/model_executor/model_runner.py +297 -56
  77. sglang/srt/model_loader/loader.py +41 -0
  78. sglang/srt/model_loader/weight_utils.py +72 -4
  79. sglang/srt/models/deepseek_nextn.py +1 -3
  80. sglang/srt/models/deepseek_v2.py +181 -45
  81. sglang/srt/models/deepseek_vl2.py +3 -5
  82. sglang/srt/models/gemma3_causal.py +1 -2
  83. sglang/srt/models/gemma3n_causal.py +4 -3
  84. sglang/srt/models/gemma3n_mm.py +4 -20
  85. sglang/srt/models/hunyuan.py +1 -1
  86. sglang/srt/models/kimi_vl.py +1 -2
  87. sglang/srt/models/llama.py +10 -4
  88. sglang/srt/models/llama4.py +32 -45
  89. sglang/srt/models/llama_eagle3.py +61 -11
  90. sglang/srt/models/llava.py +5 -5
  91. sglang/srt/models/minicpmo.py +2 -2
  92. sglang/srt/models/mistral.py +1 -1
  93. sglang/srt/models/mllama4.py +43 -11
  94. sglang/srt/models/phi4mm.py +1 -3
  95. sglang/srt/models/pixtral.py +3 -7
  96. sglang/srt/models/qwen2.py +31 -3
  97. sglang/srt/models/qwen2_5_vl.py +1 -3
  98. sglang/srt/models/qwen2_audio.py +200 -0
  99. sglang/srt/models/qwen2_moe.py +32 -6
  100. sglang/srt/models/qwen2_vl.py +1 -4
  101. sglang/srt/models/qwen3.py +94 -25
  102. sglang/srt/models/qwen3_moe.py +68 -21
  103. sglang/srt/models/vila.py +3 -8
  104. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  105. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  106. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  107. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  108. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  109. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  110. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  111. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  112. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  117. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  120. sglang/srt/operations_strategy.py +6 -2
  121. sglang/srt/reasoning_parser.py +26 -0
  122. sglang/srt/sampling/sampling_batch_info.py +39 -1
  123. sglang/srt/server_args.py +69 -22
  124. sglang/srt/speculative/build_eagle_tree.py +57 -18
  125. sglang/srt/speculative/eagle_worker.py +6 -4
  126. sglang/srt/two_batch_overlap.py +200 -27
  127. sglang/srt/utils.py +306 -146
  128. sglang/srt/warmup.py +12 -3
  129. sglang/test/runners.py +10 -1
  130. sglang/test/test_utils.py +15 -3
  131. sglang/version.py +1 -1
  132. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  133. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
  134. sglang/math_utils.py +0 -8
  135. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  136. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  137. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  138. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  139. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  140. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  141. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -159,7 +159,7 @@ class NixlKVManager(CommonKVManager):
  self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
  ):
  kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
- self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
+ self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
  logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
  if not self.kv_descs:
  raise Exception("NIXL memory registration failed for kv tensors")
@@ -168,7 +168,7 @@ class NixlKVManager(CommonKVManager):
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
  ):
  aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
- self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
+ self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
  logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
  if not self.aux_descs:
  raise Exception("NIXL memory registration failed for aux tensors")
@@ -215,8 +215,8 @@ class NixlKVManager(CommonKVManager):
  logger.debug(
  f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
  )
- src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
- dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
+ src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
  # Transfer data
  xfer_handle = self.agent.initialize_xfer(
  "WRITE",
@@ -248,8 +248,8 @@ class NixlKVManager(CommonKVManager):
  decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
  src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
  dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
- src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
- dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
+ src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
  # Transfer data
  xfer_handle = self.agent.initialize_xfer(
  "WRITE",
@@ -276,7 +276,7 @@ class SchedulerDisaggregationPrefillMixin:
  batch = self.get_new_batch_prefill()

  if require_mlp_sync(self.server_args):
- batch, _ = self.prepare_mlp_sync_batch(batch)
+ batch = self.prepare_mlp_sync_batch(batch)
  self.cur_batch = batch

  if batch:
@@ -310,7 +310,7 @@ class SchedulerDisaggregationPrefillMixin:
  batch = self.get_new_batch_prefill()

  if require_mlp_sync(self.server_args):
- batch, _ = self.prepare_mlp_sync_batch(batch)
+ batch = self.prepare_mlp_sync_batch(batch)
  self.cur_batch = batch
  if batch:
  result = self.run_batch(batch)
@@ -74,7 +74,7 @@ class ReqToMetadataIdxAllocator:
  def available_size(self):
  return len(self.free_slots)

- def alloc(self) -> List[int]:
+ def alloc(self) -> Optional[int]:
  if len(self.free_slots) == 0:
  return None

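The annotation fix above matches the allocator's actual behavior: `alloc()` returns a single metadata slot index, or `None` when no slot is free. A minimal sketch of the calling pattern this implies (the `allocator` object stands in for a `ReqToMetadataIdxAllocator` instance; not taken from the diff):

```python
# Illustrative sketch only: why the Optional[int] annotation matters to callers.
from typing import Optional

def try_assign_metadata_slot(allocator) -> Optional[int]:
    slot = allocator.alloc()  # int or None in 0.4.9, not List[int]
    if slot is None:
        # Pool exhausted: back off and retry later instead of indexing a list.
        return None
    return slot
```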
@@ -42,8 +42,10 @@ from torch.distributed import Backend, ProcessGroup
  from sglang.srt.utils import (
  direct_register_custom_op,
  get_bool_env_var,
+ get_int_env_var,
  is_cuda_alike,
  is_npu,
+ is_shm_available,
  supports_custom_op,
  )

@@ -222,6 +224,7 @@ class GroupCoordinator:
  self.local_rank = local_rank
  self.device_group = None
  self.cpu_group = None
+ self.local_size = get_int_env_var("LOCAL_SIZE", 0)

  for ranks in group_ranks:
  device_group = torch.distributed.new_group(
@@ -440,9 +443,12 @@ class GroupCoordinator:
  return input_

  if input_.is_cpu:
- import intel_extension_for_pytorch as ipex
-
- ipex.distributed.all_reduce(input_, group=self.device_group)
+ if is_shm_available(input_.dtype, self.world_size, self.local_size):
+ torch.ops.sgl_kernel.shm_allreduce(
+ input_, torch.distributed.ReduceOp.SUM
+ )
+ else:
+ torch.distributed.all_reduce(input_, group=self.device_group)
  return input_

  if not supports_custom_op():
@@ -570,6 +576,16 @@ class GroupCoordinator:
  output_tensor = torch.empty(
  output_size, dtype=input_.dtype, device=input_.device
  )
+
+ if input_.is_cpu:
+ if is_shm_available(input_.dtype, self.world_size, self.local_size):
+ return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+ else:
+ torch.distributed.all_gather_into_tensor(
+ output_tensor, input_, group=self.device_group
+ )
+ return output_tensor
+
  # All-gather.
  self.all_gather_into_tensor(output_tensor, input_)
  # Reshape
@@ -683,18 +699,25 @@ class GroupCoordinator:
  )

  # Serialize object to tensor and get the size as well
- object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
+ object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8).cuda(
+ device=torch.cuda.current_device()
+ )

  size_tensor = torch.tensor(
- [object_tensor.numel()], dtype=torch.long, device="cpu"
+ [object_tensor.numel()],
+ dtype=torch.long,
+ device=torch.cuda.current_device(),
  )

  # Send object size
-
- torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
+ torch.distributed.send(
+ size_tensor, dst=self.ranks[dst], group=self.device_group
+ )

  # Send object
- torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
+ torch.distributed.send(
+ object_tensor, dst=self.ranks[dst], group=self.device_group
+ )

  return None

@@ -708,29 +731,31 @@ class GroupCoordinator:
  src != self.rank_in_group
  ), "Invalid source rank. Source rank is the same as the current rank."

- size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
+ size_tensor = torch.empty(
+ 1, dtype=torch.long, device=torch.cuda.current_device()
+ )

  # Receive object size
  rank_size = torch.distributed.recv(
- size_tensor, src=self.ranks[src], group=self.cpu_group
+ size_tensor, src=self.ranks[src], group=self.device_group
  )

  # Tensor to receive serialized objects into.
  object_tensor = torch.empty( # type: ignore[call-overload]
  size_tensor.item(), # type: ignore[arg-type]
  dtype=torch.uint8,
- device="cpu",
+ device=torch.cuda.current_device(),
  )

  rank_object = torch.distributed.recv(
- object_tensor, src=self.ranks[src], group=self.cpu_group
+ object_tensor, src=self.ranks[src], group=self.device_group
  )

  assert (
  rank_object == rank_size
  ), "Received object sender rank does not match the size sender rank."

- obj = pickle.loads(object_tensor.numpy().tobytes())
+ obj = pickle.loads(object_tensor.cpu().numpy().tobytes())

  return obj

@@ -841,14 +866,16 @@ class GroupCoordinator:
  dst = (self.rank_in_group + 1) % self.world_size
  assert dst < self.world_size, f"Invalid dst rank ({dst})"

- metadata_list: List[Tuple[Any, Any]] = []
  assert isinstance(
  tensor_dict, dict
  ), f"Expecting a dictionary, got {type(tensor_dict)}"
  metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
- # `metadata_list` lives in CPU memory.
- # `send_object_list` has serialization & deserialization,
- # all happening on CPU. Therefore, we can use the CPU group.
+ # Note: While switching to Device-to-Device (D2D) would introduce an extra
+ # Device-to-Host (D2H) memory copy overhead for serialization, our benchmarks
+ # show better overall transmission performance with D2D due to:
+ # 1. Superior D2D transfer bandwidth
+ # 2. Ability to overlap send and recv operations
+ # Thus the net performance gain justifies this approach.
  self.send_object(metadata_list, dst=dst)
  for tensor in tensor_list:
  if tensor.numel() == 0:
@@ -48,6 +48,14 @@ class EngineBase(ABC):
  """Update model weights with in-memory tensor data."""
  pass

+ def load_lora_adapter(self, lora_name: str, lora_path: str):
+ """Load a new LoRA adapter without re-launching the engine."""
+ pass
+
+ def unload_lora_adapter(self, lora_name: str):
+ """Unload a LoRA adapter without re-launching the engine."""
+ pass
+
  @abstractmethod
  def release_memory_occupation(self):
  """Release GPU memory occupation temporarily."""
@@ -48,10 +48,12 @@ from sglang.srt.managers.io_struct import (
  GetWeightsByNameReqInput,
  ImageDataItem,
  InitWeightsUpdateGroupReqInput,
+ LoadLoRAAdapterReqInput,
  ReleaseMemoryOccupationReqInput,
  ResumeMemoryOccupationReqInput,
  RpcReqInput,
  RpcReqOutput,
+ UnloadLoRAAdapterReqInput,
  UpdateWeightFromDiskReqInput,
  UpdateWeightsFromDistributedReqInput,
  UpdateWeightsFromTensorReqInput,
@@ -416,12 +418,21 @@ class Engine(EngineBase):
  self.tokenizer_manager.init_weights_update_group(obj, None)
  )

- def update_weights_from_distributed(self, name: str, dtype, shape):
+ def update_weights_from_distributed(
+ self,
+ names: list[str],
+ dtypes: list[str],
+ shapes: list[list[int]],
+ group_name: str = "weight_update_group",
+ flush_cache: bool = True,
+ ):
  """Update weights from distributed source."""
  obj = UpdateWeightsFromDistributedReqInput(
- name=name,
- dtype=dtype,
- shape=shape,
+ names=names,
+ dtypes=dtypes,
+ shapes=shapes,
+ group_name=group_name,
+ flush_cache=flush_cache,
  )
  loop = asyncio.get_event_loop()
  return loop.run_until_complete(
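The hunk above changes `Engine.update_weights_from_distributed` from a single-tensor call to a batched one, with an explicit group name and a cache-flush switch. A hedged sketch of a call against the new signature (the weight names, dtypes, and shapes are placeholders; `engine` is assumed to be an sglang Engine with a weight-update group already initialized):

```python
# Sketch against the 0.4.9 signature shown above; all tensor metadata is made up.
engine.update_weights_from_distributed(
    names=["model.embed_tokens.weight", "lm_head.weight"],
    dtypes=["bfloat16", "bfloat16"],
    shapes=[[128256, 4096], [128256, 4096]],
    group_name="weight_update_group",  # default from the new signature
    flush_cache=True,                  # flush cached prefixes after the swap
)
```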
@@ -478,6 +489,29 @@ class Engine(EngineBase):
  self.tokenizer_manager.get_weights_by_name(obj, None)
  )

+ def load_lora_adapter(self, lora_name: str, lora_path: str):
+ """Load a new LoRA adapter without re-launching the engine."""
+
+ obj = LoadLoRAAdapterReqInput(
+ lora_name=lora_name,
+ lora_path=lora_path,
+ )
+
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.load_lora_adapter(obj, None)
+ )
+
+ def unload_lora_adapter(self, lora_name: str):
+ """Unload a LoRA adapter without re-launching the engine."""
+
+ obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
+
+ loop = asyncio.get_event_loop()
+ return loop.run_until_complete(
+ self.tokenizer_manager.unload_lora_adapter(obj, None)
+ )
+
  def release_memory_occupation(self, tags: Optional[List[str]] = None):
  obj = ReleaseMemoryOccupationReqInput(tags=tags)
  loop = asyncio.get_event_loop()
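Together with the `EngineBase` stubs earlier in this diff, these two methods let LoRA adapters be attached and detached on a live engine. A minimal usage sketch (the adapter name and path are placeholders; `engine` is assumed to be an sglang Engine launched with LoRA support enabled):

```python
# Sketch: dynamic LoRA adapter management on an already-running Engine.
engine.load_lora_adapter(lora_name="my_adapter", lora_path="/path/to/adapter")
# ... serve requests that reference "my_adapter" ...
engine.unload_lora_adapter(lora_name="my_adapter")
```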
@@ -608,7 +642,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  if server_args.attention_backend == "flashinfer":
  assert_pkg_version(
  "flashinfer_python",
- "0.2.6.post1",
+ "0.2.7.post1",
  "Please uninstall the old version and "
  "reinstall the latest version by following the instructions "
  "at https://docs.flashinfer.ai/installation.html.",
@@ -616,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  if _is_cuda:
  assert_pkg_version(
  "sgl-kernel",
- "0.1.9",
+ "0.2.4",
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
  )

@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
  GenerateReqInput,
  GetWeightsByNameReqInput,
  InitWeightsUpdateGroupReqInput,
+ LoadLoRAAdapterReqInput,
  OpenSessionReqInput,
  ParseFunctionCallReq,
  ProfileReqInput,
@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
  SeparateReasoningReqInput,
  SetInternalStateReq,
  SlowDownReqInput,
+ UnloadLoRAAdapterReqInput,
  UpdateWeightFromDiskReqInput,
  UpdateWeightsFromDistributedReqInput,
  UpdateWeightsFromTensorReqInput,
@@ -124,8 +126,6 @@ def set_global_state(global_state: _GlobalState):

  @asynccontextmanager
  async def lifespan(fast_api_app: FastAPI):
- server_args: ServerArgs = fast_api_app.server_args
-
  # Initialize OpenAI serving handlers
  fast_api_app.state.openai_serving_completion = OpenAIServingCompletion(
  _global_state.tokenizer_manager, _global_state.template_manager
@@ -143,9 +143,12 @@ async def lifespan(fast_api_app: FastAPI):
  _global_state.tokenizer_manager
  )

+ server_args: ServerArgs = fast_api_app.server_args
  if server_args.warmups is not None:
  await execute_warmups(
- server_args.warmups.split(","), _global_state.tokenizer_manager
+ server_args.disaggregation_mode,
+ server_args.warmups.split(","),
+ _global_state.tokenizer_manager,
  )
  logger.info("Warmup ended")

@@ -278,13 +281,17 @@ async def get_model_info():
  "model_path": _global_state.tokenizer_manager.model_path,
  "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path,
  "is_generation": _global_state.tokenizer_manager.is_generation,
+ "preferred_sampling_params": _global_state.tokenizer_manager.server_args.preferred_sampling_params,
  }
  return result


  @app.get("/get_server_info")
  async def get_server_info():
- internal_states = await _global_state.tokenizer_manager.get_internal_state()
+ # Returns interna states per DP.
+ internal_states: List[Dict[Any, Any]] = (
+ await _global_state.tokenizer_manager.get_internal_state()
+ )
  return {
  **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
  **_global_state.scheduler_info,
@@ -298,6 +305,8 @@ async def get_load():
  return await _global_state.tokenizer_manager.get_load()


+ # example usage:
+ # curl -s -X POST http://localhost:30000/set_internal_state -H "Content-Type: application/json" -d '{"server_args": {"max_micro_batch_size": 8}}'
  @app.api_route("/set_internal_state", methods=["POST", "PUT"])
  async def set_internal_state(obj: SetInternalStateReq, request: Request):
  res = await _global_state.tokenizer_manager.set_internal_state(obj)
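The curl example added in the comment above maps directly onto a Python client call; a hedged equivalent with `requests` (host, port, and payload mirror the diff's own comment):

```python
# Sketch: Python equivalent of the curl example from the added comment.
import requests

resp = requests.post(
    "http://localhost:30000/set_internal_state",
    json={"server_args": {"max_micro_batch_size": 8}},
)
print(resp.json())
```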
@@ -351,8 +360,7 @@ async def generate_from_file_request(file: UploadFile, request: Request):
  obj = GenerateReqInput(
  input_embeds=input_embeds,
  sampling_params={
- "repetition_penalty": 1.2,
- "temperature": 0.2,
+ "temperature": 0.0,
  "max_new_tokens": 512,
  },
  )
@@ -391,16 +399,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
  return _create_error_response(e)


- @app.api_route(
- "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
- )
- async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
- """Endpoint for reranking documents based on query relevance."""
- return await raw_request.app.state.openai_serving_rerank.handle_request(
- request, raw_request
- )
-
-
  @app.api_route("/flush_cache", methods=["GET", "POST"])
  async def flush_cache():
  """Flush the radix cache."""
@@ -595,6 +593,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
  return _create_error_response(e)


+ @app.api_route("/load_lora_adapter", methods=["POST"])
+ async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
+ """Load a new LoRA adapter without re-launching the server."""
+ result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)
+
+ if result.success:
+ return ORJSONResponse(
+ result,
+ status_code=HTTPStatus.OK,
+ )
+ else:
+ return ORJSONResponse(
+ result,
+ status_code=HTTPStatus.BAD_REQUEST,
+ )
+
+
+ @app.api_route("/unload_lora_adapter", methods=["POST"])
+ async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
+ """Load a new LoRA adapter without re-launching the server."""
+ result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)
+
+ if result.success:
+ return ORJSONResponse(
+ result,
+ status_code=HTTPStatus.OK,
+ )
+ else:
+ return ORJSONResponse(
+ result,
+ status_code=HTTPStatus.BAD_REQUEST,
+ )
+
+
  @app.api_route("/open_session", methods=["GET", "POST"])
  async def open_session(obj: OpenSessionReqInput, request: Request):
  """Open a session, and return its unique session id."""
@@ -630,7 +662,9 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
  async def abort_request(obj: AbortReq, request: Request):
  """Abort a request."""
  try:
- _global_state.tokenizer_manager.abort_request(rid=obj.rid)
+ _global_state.tokenizer_manager.abort_request(
+ rid=obj.rid, abort_all=obj.abort_all
+ )
  return Response(status_code=200)
  except Exception as e:
  return _create_error_response(e)
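`AbortReq` now carries an `abort_all` flag that the handler forwards to the tokenizer manager, so a client can cancel every in-flight request at once. A hedged sketch (the `/abort_request` route path and the exact field names are assumptions inferred from this handler, which the diff does not show in full):

```python
# Sketch: aborting one request vs. all requests; route path and fields are assumed.
import requests

base = "http://localhost:30000"
requests.post(f"{base}/abort_request", json={"rid": "some-request-id"})
requests.post(f"{base}/abort_request", json={"abort_all": True})
```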
@@ -678,6 +712,26 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
  return ORJSONResponse(content=response_data, status_code=200)


+ @app.post("/pause_generation")
+ async def pause_generation(request: Request):
+ """Pause generation."""
+ await _global_state.tokenizer_manager.pause_generation()
+ return ORJSONResponse(
+ content={"message": "Generation paused successfully.", "status": "ok"},
+ status_code=200,
+ )
+
+
+ @app.post("/continue_generation")
+ async def continue_generation(request: Request):
+ """Continue generation."""
+ await _global_state.tokenizer_manager.continue_generation()
+ return ORJSONResponse(
+ content={"message": "Generation continued successfully.", "status": "ok"},
+ status_code=200,
+ )
+
+
  ##### OpenAI-compatible API endpoints #####

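The two new routes take no request body, so pausing and resuming a running server is a one-liner each (host and port are placeholders):

```python
# Sketch: pausing and resuming generation on a running server.
import requests

base = "http://localhost:30000"
print(requests.post(f"{base}/pause_generation").json())     # {"message": ..., "status": "ok"}
print(requests.post(f"{base}/continue_generation").json())
```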
@@ -805,6 +859,16 @@ async def v1_score_request(request: ScoringRequest, raw_request: Request):
  )


+ @app.api_route(
+ "/v1/rerank", methods=["POST", "PUT"], dependencies=[Depends(validate_json_request)]
+ )
+ async def v1_rerank_request(request: V1RerankReqInput, raw_request: Request):
+ """Endpoint for reranking documents based on query relevance."""
+ return await raw_request.app.state.openai_serving_rerank.handle_request(
+ request, raw_request
+ )
+
+
  def _create_error_response(e):
  return ORJSONResponse(
  {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
@@ -851,6 +915,15 @@ def launch_server(
  add_prometheus_middleware(app)
  enable_func_timer()

+ image_token_text = None
+ if (
+ tokenizer_manager.image_token_id is not None
+ and not server_args.skip_tokenizer_init
+ ):
+ image_token_text = tokenizer_manager.tokenizer.decode(
+ [tokenizer_manager.image_token_id]
+ )
+
  # Send a warmup request - we will create the thread launch it
  # in the lifespan after all other warmups have fired.
  warmup_thread = threading.Thread(
@@ -858,7 +931,7 @@ def launch_server(
  args=(
  server_args,
  pipe_finish_writer,
- _global_state.tokenizer_manager.image_token_id,
+ image_token_text,
  launch_callback,
  ),
  )
@@ -881,11 +954,9 @@ def launch_server(
  warmup_thread.join()


- def _wait_and_warmup(
+ def _execute_server_warmup(
  server_args: ServerArgs,
  pipe_finish_writer: Optional[multiprocessing.connection.Connection],
- image_token_text: str,
- launch_callback: Optional[Callable[[], None]] = None,
  ):
  headers = {}
  url = server_args.url()
@@ -910,7 +981,7 @@ def _wait_and_warmup(
  pipe_finish_writer.send(last_traceback)
  logger.error(f"Initialization failed. warmup error: {last_traceback}")
  kill_process_tree(os.getpid())
- return
+ return success

  model_info = res.json()

@@ -984,12 +1055,28 @@ def _wait_and_warmup(
  pipe_finish_writer.send(last_traceback)
  logger.error(f"Initialization failed. warmup error: {last_traceback}")
  kill_process_tree(os.getpid())
- return
+ return False

  # Debug print
- # logger.info(f"{res.json()=}")
+ # logger.info(f"warmup request returns: {res.json()=}")
+ return success
+
+
+ def _wait_and_warmup(
+ server_args: ServerArgs,
+ pipe_finish_writer: Optional[multiprocessing.connection.Connection],
+ image_token_text: str,
+ launch_callback: Optional[Callable[[], None]] = None,
+ ):
+ if not server_args.skip_server_warmup:
+ if not _execute_server_warmup(
+ server_args,
+ pipe_finish_writer,
+ ):
+ return

  logger.info("The server is fired up and ready to roll!")
+
  if pipe_finish_writer is not None:
  pipe_finish_writer.send("ready")

@@ -236,7 +236,7 @@ class CompletionResponseStreamChoice(BaseModel):
  index: int
  text: str
  logprobs: Optional[LogProbs] = None
- finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
+ finish_reason: Optional[Literal["stop", "length", "content_filter", "abort"]] = None
  matched_stop: Union[None, int, str] = None
  hidden_states: Optional[object] = None

@@ -510,7 +510,9 @@ class ChatCompletionResponseStreamChoice(BaseModel):
  delta: DeltaMessage
  logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
  finish_reason: Optional[
- Literal["stop", "length", "tool_calls", "content_filter", "function_call"]
+ Literal[
+ "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
+ ]
  ] = None
  matched_stop: Union[None, int, str] = None

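Both streaming choice models now admit `finish_reason == "abort"`, so streaming clients should treat it as a terminal state alongside "stop" and "length". A small sketch with the OpenAI-compatible client (base URL and model name are placeholders for a running sglang server):

```python
# Sketch: handling the new "abort" finish reason while consuming a stream.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    if not chunk.choices:
        continue
    if chunk.choices[0].finish_reason == "abort":
        # The request was aborted server-side; stop consuming the stream.
        break
```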
@@ -3,7 +3,7 @@ from typing import Optional

  import torch

- from sglang.srt.managers.eplb_algorithms import deepseek, deepseek_vec
+ from sglang.srt.eplb.eplb_algorithms import deepseek, deepseek_vec


  class EplbAlgorithm(Enum):
@@ -4,10 +4,8 @@ from typing import TYPE_CHECKING, List

  import torch.cuda

- from sglang.srt.managers.expert_distribution import (
- get_global_expert_distribution_recorder,
- )
- from sglang.srt.managers.expert_location import ExpertLocationMetadata
+ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
+ from sglang.srt.eplb.expert_location import ExpertLocationMetadata

  if TYPE_CHECKING:
  from sglang.srt.model_executor.model_runner import ModelRunner
@@ -4,7 +4,7 @@ from pathlib import Path
  import torch
  from tqdm import tqdm

- from sglang.srt.managers.expert_distribution import (
+ from sglang.srt.eplb.expert_distribution import (
  _convert_global_physical_count_to_logical_count,
  )

@@ -24,7 +24,7 @@ import einops
  import torch
  import torch.distributed

- from sglang.srt.managers.expert_location import ExpertLocationMetadata
+ from sglang.srt.eplb.expert_location import ExpertLocationMetadata
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.server_args import ServerArgs
@@ -479,10 +479,6 @@ class _SelectExpertsSinglePassGatherer(_LayerBasedGpuSinglePassGatherer):
  def on_select_experts(self, layer_idx: int, topk_ids: torch.Tensor):
  topk_ids = topk_ids.flatten()
  mask = topk_ids != -1
- assert self._data[layer_idx, :].shape == topk_ids.shape, (
- "Shape mismatch between data and topk_ids."
- "Selecting expert is not supported for multiple token prediction at the moment."
- )
  self._data[layer_idx, :].scatter_add_(
  dim=0, index=topk_ids.masked_fill(~mask, 0).long(), src=mask.int()
  )
@@ -23,7 +23,7 @@ import torch.distributed
  import torch.nn.functional as F

  from sglang.srt.configs.model_config import ModelConfig
- from sglang.srt.managers import eplb_algorithms
+ from sglang.srt.eplb import eplb_algorithms
  from sglang.srt.model_loader import get_model_architecture
  from sglang.srt.server_args import ServerArgs

@@ -17,7 +17,7 @@ from typing import Literal, Optional

  import torch

- from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+ from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
  from sglang.srt.managers.schedule_batch import global_server_args_dict
