sglang 0.4.10.post1__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. sglang/compile_deep_gemm.py +8 -1
  2. sglang/global_config.py +5 -1
  3. sglang/srt/conversation.py +0 -112
  4. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  5. sglang/srt/disaggregation/prefill.py +1 -0
  6. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  7. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  8. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  9. sglang/srt/distributed/parallel_state.py +11 -0
  10. sglang/srt/entrypoints/engine.py +4 -2
  11. sglang/srt/entrypoints/http_server.py +35 -15
  12. sglang/srt/eplb/expert_distribution.py +4 -2
  13. sglang/srt/hf_transformers_utils.py +25 -10
  14. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  15. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  16. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  17. sglang/srt/layers/attention/vision.py +27 -10
  18. sglang/srt/layers/communicator.py +14 -4
  19. sglang/srt/layers/linear.py +7 -1
  20. sglang/srt/layers/logits_processor.py +9 -1
  21. sglang/srt/layers/moe/ep_moe/layer.py +11 -35
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +26 -23
  24. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  25. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  26. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  27. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  28. sglang/srt/layers/moe/utils.py +43 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  31. sglang/srt/layers/quantization/fp8.py +5 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  33. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  34. sglang/srt/lora/lora_registry.py +7 -0
  35. sglang/srt/managers/cache_controller.py +8 -4
  36. sglang/srt/managers/data_parallel_controller.py +52 -2
  37. sglang/srt/managers/io_struct.py +6 -1
  38. sglang/srt/managers/schedule_batch.py +3 -2
  39. sglang/srt/managers/schedule_policy.py +3 -1
  40. sglang/srt/managers/scheduler.py +144 -6
  41. sglang/srt/managers/template_manager.py +25 -22
  42. sglang/srt/managers/tokenizer_manager.py +114 -62
  43. sglang/srt/managers/utils.py +45 -1
  44. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  45. sglang/srt/mem_cache/hicache_storage.py +13 -21
  46. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  47. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  48. sglang/srt/model_executor/cuda_graph_runner.py +17 -3
  49. sglang/srt/model_executor/forward_batch_info.py +13 -3
  50. sglang/srt/model_executor/model_runner.py +5 -0
  51. sglang/srt/models/deepseek_v2.py +23 -17
  52. sglang/srt/models/glm4_moe.py +82 -19
  53. sglang/srt/models/grok.py +3 -3
  54. sglang/srt/models/llama4.py +13 -2
  55. sglang/srt/models/mixtral.py +3 -3
  56. sglang/srt/models/mllama4.py +428 -19
  57. sglang/srt/models/qwen2_moe.py +1 -4
  58. sglang/srt/models/qwen3_moe.py +7 -8
  59. sglang/srt/models/step3_vl.py +1 -1
  60. sglang/srt/multimodal/processors/base_processor.py +4 -3
  61. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  62. sglang/srt/operations_strategy.py +1 -1
  63. sglang/srt/server_args.py +80 -20
  64. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  65. sglang/srt/two_batch_overlap.py +6 -4
  66. sglang/srt/utils.py +3 -24
  67. sglang/srt/weight_sync/utils.py +1 -1
  68. sglang/test/runners.py +2 -2
  69. sglang/test/test_utils.py +3 -3
  70. sglang/version.py +1 -1
  71. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  72. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +80 -74
  73. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  74. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  75. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  76. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  77. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  78. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  79. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  80. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -64,6 +64,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -125,7 +126,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import validate_input_length
+from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -137,7 +138,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
-    DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
     configure_gc_logger,
@@ -203,6 +203,7 @@ class Scheduler(
         moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
+        dp_balance_meta: Optional[DPBalanceMeta] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -522,6 +523,15 @@ class Scheduler(
            ]
        )

+        self.balance_meta = dp_balance_meta
+        if (
+            server_args.enable_dp_attention
+            and server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+            self.recv_dp_balance_id_this_term = []
+
    def init_tokenizer(self):
        server_args = self.server_args

@@ -569,7 +579,23 @@ class Scheduler(
                page_size=self.page_size,
            )
        else:
-            if self.enable_hierarchical_cache:
+            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
+                # lazy import to avoid JIT overhead
+                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
+
+                self.tree_cache = RadixCacheCpp(
+                    disable=False,
+                    use_hicache=self.enable_hierarchical_cache,
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_cpu_group,
+                    page_size=self.page_size,
+                    hicache_ratio=server_args.hicache_ratio,
+                    hicache_size=server_args.hicache_size,
+                    hicache_write_policy=server_args.hicache_write_policy,
+                    enable_kv_cache_events=self.enable_kv_cache_events,
+                )
+            elif self.enable_hierarchical_cache:
                self.tree_cache = HiRadixCache(
                    req_to_token_pool=self.req_to_token_pool,
                    token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
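Note: the experimental C++ radix tree branch above is opt-in and is only taken when the environment variable is set before the scheduler process builds its tree cache. A minimal sketch of enabling it; the variable name comes from the hunk above, while the launch command and flags are illustrative and not part of this diff:

import os

# Must be set in the scheduler's environment before the tree cache is constructed.
os.environ["SGLANG_EXPERIMENTAL_CPP_RADIX_TREE"] = "1"

# Then start the server as usual, e.g.:
#   python -m sglang.launch_server --model-path <model-path>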
@@ -1033,6 +1059,12 @@ class Scheduler(
        self,
        recv_req: TokenizedGenerateReqInput,
    ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
        # Create a new request
        if (
            recv_req.session_params is None
@@ -1443,6 +1475,11 @@ class Scheduler(

        # Handle DP attention
        if need_dp_attn_preparation:
+            if (
+                self.server_args.load_balance_method == "minimum_tokens"
+                and self.forward_ct % 40 == 0
+            ):
+                self.handle_dp_balance_data(ret)
            ret = self.prepare_mlp_sync_batch(ret)

        return ret
@@ -1744,6 +1781,9 @@ class Scheduler(
        elif batch.forward_mode.is_dummy_first():
            self.set_next_batch_sampling_info_done(batch)

+        self.maybe_send_health_check_signal()
+
+    def maybe_send_health_check_signal(self):
        if self.return_health_check_ct:
            # Return some signal for the health check.
            # This is used to prevent the health check signal being blocked by long context prefill.
@@ -1762,12 +1802,94 @@ class Scheduler(
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-            enable_deepep_moe=self.server_args.enable_deepep_moe,
-            deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+            enable_deepep_moe=MoeA2ABackend(
+                self.server_args.moe_a2a_backend
+            ).is_deepep(),
+            deepep_mode=DeepEPMode(self.server_args.deepep_mode),
            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
        )

+    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
+        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
+            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+            recv_list = self.recv_dp_balance_id_this_term
+            assert len(recv_list) <= 511, (
+                "The number of requests received this round is too large. "
+                "Please increase gather_tensor_size and onfly_info_size."
+            )
+            # The maximum size of the tensor used for gathering data from all workers.
+            gather_tensor_size = 512
+
+            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+            recv_tensor[0] = holding_tokens_list
+            recv_tensor[1] = len(
+                recv_list
+            )  # The first element is the length of the list.
+            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
+                recv_list, dtype=torch.int32
+            )
+
+            if self.tp_rank == 0:
+                gathered_list = [
+                    torch.zeros(gather_tensor_size, dtype=torch.int32)
+                    for _ in range(self.balance_meta.num_workers)
+                ]
+            else:
+                gathered_list = None
+
+            torch.distributed.gather(
+                recv_tensor, gathered_list, group=self.tp_cpu_group
+            )
+
+            gathered_id_list_per_worker = None
+            if self.tp_rank == 0:
+                gathered_id_list_per_worker = []
+                holding_tokens_list = []
+                for tensor in gathered_list:
+                    holding_tokens_list.append(tensor[0].item())
+                    list_length = tensor[1].item()
+                    gathered_id_list_per_worker.append(
+                        tensor[2 : list_length + 2].tolist()
+                    )
+
+            return gathered_id_list_per_worker, holding_tokens_list
+
+        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
+            meta = self.balance_meta
+
+            with meta.mutex:
+                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+                assert len(new_recv_rid_lists) == len(
+                    onfly_list
+                ), "num_worker not equal"
+                # 1. Check if the rid received by each worker this round is present in onfly.
+                #    If it is, remove the corresponding onfly item.
+                worker_id = 0
+                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                    for new_recv_rid in new_recv_rids:
+                        assert (
+                            new_recv_rid in on_fly_reqs
+                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                        del on_fly_reqs[new_recv_rid]
+                    worker_id += 1
+                # 2. Atomically write local_tokens and onfly into shm under the mutex
+                meta.set_shared_onfly_info(onfly_list)
+                meta.set_shared_local_tokens(local_tokens)
+
+        holding_tokens = self.get_load()
+
+        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
+            holding_tokens
+        )
+
+        self.recv_dp_balance_id_this_term.clear()
+        if self.tp_rank == 0:  # only first worker write info
+            write_shared_dp_balance_info(
+                new_recv_dp_balance_id_list, holding_token_list
+            )
+
    @staticmethod
    def prepare_mlp_sync_batch_raw(
        local_batch: ScheduleBatch,
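Note: the new handle_dp_balance_data packs each worker's state into a fixed-size int32 tensor before the CPU-group gather. A standalone sketch of that layout; the helper names below are illustrative, not the Scheduler's actual API, and the 512-slot size mirrors gather_tensor_size in the hunk above:

import torch

GATHER_TENSOR_SIZE = 512  # same fixed size as gather_tensor_size above

def pack_balance_tensor(holding_tokens: int, recv_ids: list) -> torch.Tensor:
    # layout: | holding_tokens | len(recv_ids) | recv_ids ... | zero padding |
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    t[2 : 2 + len(recv_ids)] = torch.tensor(recv_ids, dtype=torch.int32)
    return t

def unpack_balance_tensor(t: torch.Tensor):
    n = int(t[1].item())
    return int(t[0].item()), t[2 : 2 + n].tolist()

# Round trip: rank 0 unpacks one such tensor per worker after torch.distributed.gather.
packed = pack_balance_tensor(1024, [7, 8, 9])
assert unpack_balance_tensor(packed) == (1024, [7, 8, 9])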
@@ -2344,11 +2466,19 @@ class IdleSleeper:

    def __init__(self, sockets):
        self.poller = zmq.Poller()
+        self.last_empty_time = time.time()
        for s in sockets:
            self.poller.register(s, zmq.POLLIN)

    def maybe_sleep(self):
        self.poller.poll(1000)
+        if (
+            global_config.torch_empty_cache_interval > 0
+            and time.time() - self.last_empty_time
+            > global_config.torch_empty_cache_interval
+        ):
+            self.last_empty_time = time.time()
+            torch.cuda.empty_cache()


 def is_health_check_generate_req(recv_req):
@@ -2368,6 +2498,7 @@ def run_scheduler_process(
    pp_rank: int,
    dp_rank: Optional[int],
    pipe_writer,
+    balance_meta: Optional[DPBalanceMeta] = None,
 ):
    # Generate the prefix
    prefix = ""
@@ -2401,7 +2532,14 @@
    # Create a scheduler and run the event loop
    try:
        scheduler = Scheduler(
-            server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
+            server_args,
+            port_args,
+            gpu_id,
+            tp_rank,
+            moe_ep_rank,
+            pp_rank,
+            dp_rank,
+            dp_balance_meta=balance_meta,
        )
        pipe_writer.send(
            {
sglang/srt/managers/template_manager.py

@@ -84,26 +84,27 @@ class TemplateManager:
        if chat_template_arg:
            self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
        else:
-            # Try HuggingFace template first
-            hf_template = self._resolve_hf_chat_template(tokenizer_manager)
-            if hf_template:
-                self._jinja_template_content_format = (
-                    detect_jinja_template_content_format(hf_template)
-                )
-                logger.info(
-                    f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
-                )
-                return
-
-            # Fallback to SGLang template guessing
+            # Guess chat template from model path
            self.guess_chat_template_from_model_path(model_path)

-            # Set default format if no template was found
+            # If no pre-defined template was found, fallback to HuggingFace template
            if self._chat_template_name is None:
-                self._jinja_template_content_format = "string"
-                logger.info(
-                    "No chat template found, defaulting to 'string' content format"
-                )
+                # Try HuggingFace template first
+                hf_template = self._resolve_hf_chat_template(tokenizer_manager)
+                if hf_template:
+                    # override the chat template
+                    tokenizer_manager.tokenizer.chat_template = hf_template
+                    self._jinja_template_content_format = (
+                        detect_jinja_template_content_format(hf_template)
+                    )
+                    logger.info(
+                        f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
+                    )
+                    return
+
+                # Default to string content format if no template was found
+                self._jinja_template_content_format = "string"
+                logger.info("No chat template found, defaulting to 'string' content format")

    def _load_explicit_chat_template(
        self, tokenizer_manager, chat_template_arg: str
@@ -257,13 +258,15 @@

        Returns the chat template string if found, None otherwise.
        """
-        tokenizer = tokenizer_manager.tokenizer
-
-        # Try to get AutoTokenizer chat template
        try:
-            return tokenizer.get_chat_template()
+            if processor := tokenizer_manager.processor:
+                if hasattr(processor, "chat_template") and processor.chat_template:
+                    return processor.chat_template
+            if tokenizer := tokenizer_manager.tokenizer:
+                if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+                    return tokenizer.chat_template
        except Exception as e:
-            logger.debug(f"Error getting chat template via get_chat_template(): {e}")
+            logger.debug(f"Error getting chat template: {e}")

        logger.debug("No HuggingFace chat template found")
        return None
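Note: taken together, the two hunks above change the resolution order: an explicit --chat-template argument still wins, then SGLang's built-in guess from the model path, then the HuggingFace template (processor.chat_template checked before tokenizer.chat_template), with the plain "string" content format as the last resort. A simplified, self-contained sketch of that decision; the function and argument names are illustrative only and are not the TemplateManager API:

def pick_chat_template(chat_template_arg, sglang_guess, processor_template, tokenizer_template):
    # 1. explicit --chat-template argument wins
    if chat_template_arg:
        return "explicit", chat_template_arg
    # 2. built-in SGLang template matched from the model path
    if sglang_guess:
        return "sglang", sglang_guess
    # 3. HuggingFace template: processor template is preferred over the tokenizer's
    if processor_template:
        return "huggingface-processor", processor_template
    if tokenizer_template:
        return "huggingface-tokenizer", tokenizer_template
    # 4. nothing found: fall back to the 'string' content format
    return "string-format-default", None

print(pick_chat_template(None, None, None, "{{ messages }}"))  # ('huggingface-tokenizer', '{{ messages }}')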
sglang/srt/managers/tokenizer_manager.py

@@ -29,6 +29,7 @@ import uuid
 from collections import deque
 from contextlib import nullcontext
 from datetime import datetime
+from enum import Enum
 from http import HTTPStatus
 from typing import (
     Any,
@@ -70,7 +71,6 @@ from sglang.srt.managers.io_struct import (
     BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
-    BlockReqType,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -116,6 +116,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
+from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -202,13 +203,29 @@ class TokenizerManager:

        if self.model_config.is_multimodal:
            import_processors()
-            _processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-                revision=server_args.revision,
-                use_fast=not server_args.disable_fast_image_processor,
-            )
+            try:
+                _processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                    revision=server_args.revision,
+                    use_fast=not server_args.disable_fast_image_processor,
+                )
+            except ValueError as e:
+                error_message = str(e)
+                if "does not have a slow version" in error_message:
+                    logger.info(
+                        f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version"
+                    )
+                    _processor = get_processor(
+                        server_args.tokenizer_path,
+                        tokenizer_mode=server_args.tokenizer_mode,
+                        trust_remote_code=server_args.trust_remote_code,
+                        revision=server_args.revision,
+                        use_fast=True,
+                    )
+                else:
+                    raise e
            transport_mode = _determine_tensor_transport_mode(self.server_args)

            # We want to parallelize the image pre-processing so we create an executor for it
@@ -225,10 +242,10 @@ class TokenizerManager:
            self.tokenizer = get_tokenizer_from_processor(self.processor)
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
        else:
-            self.mm_processor = None
+            self.mm_processor = self.processor = None

            if server_args.skip_tokenizer_init:
-                self.tokenizer = self.processor = None
+                self.tokenizer = None
            else:
                self.tokenizer = get_tokenizer(
                    server_args.tokenizer_path,
@@ -255,6 +272,7 @@ class TokenizerManager:
        self.health_check_failed = False
        self.gracefully_exit = False
        self.last_receive_tstamp = 0
+        self.server_status = ServerStatus.Starting

        # Dumping
        self.dump_requests_folder = ""  # By default do not dump
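Note: server_status is a new TokenizerManager field; the ServerStatus enum itself is added near the end of this file (see the @@ -1901 hunk below). A hedged sketch of how a health check might consult it; the actual wiring lives in http_server.py (also changed in this release) and may differ:

from sglang.srt.managers.tokenizer_manager import ServerStatus  # added in this release

status = ServerStatus.Starting       # the state a TokenizerManager starts in
print(status.is_healthy())           # False until the manager flips it to ServerStatus.Up
print(ServerStatus.Up.is_healthy())  # True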
@@ -1069,38 +1087,56 @@ class TokenizerManager:
        _: Optional[fastapi.Request] = None,
    ) -> LoadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )

-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start load Lora adapter. Lora name=%s, path=%s",
-            obj.lora_name,
-            obj.lora_path,
-        )
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )

-        async with self.lora_update_lock:
-            # Generate new uniquely identifiable LoRARef object.
-            new_adapter = LoRARef(
-                lora_name=obj.lora_name,
-                lora_path=obj.lora_path,
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start load Lora adapter. Lora name=%s, path=%s",
+                obj.lora_name,
+                obj.lora_path,
            )

-            # Trigger the actual loading operation at the backend processes.
-            obj.lora_id = new_adapter.lora_id
-            result = (await self.update_lora_adapter_communicator(obj))[0]
+            async with self.lora_update_lock:
+                if (
+                    self.server_args.max_loaded_loras is not None
+                    and self.lora_registry.num_registered_loras
+                    >= self.server_args.max_loaded_loras
+                ):
+                    raise ValueError(
+                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
+                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
+                        "Please unload some LoRA adapters before loading new ones."
+                    )

-            # Register the LoRA adapter only after loading is successful.
-            if result.success:
-                await self.lora_registry.register(new_adapter)
+                # Generate new uniquely identifiable LoRARef object.
+                new_adapter = LoRARef(
+                    lora_name=obj.lora_name,
+                    lora_path=obj.lora_path,
+                )

-            return result
+                # Trigger the actual loading operation at the backend processes.
+                obj.lora_id = new_adapter.lora_id
+                result = (await self.update_lora_adapter_communicator(obj))[0]
+
+                # Register the LoRA adapter only after loading is successful.
+                if result.success:
+                    await self.lora_registry.register(new_adapter)
+
+                return result
+        except ValueError as e:
+            return LoadLoRAAdapterReqOutput(
+                success=False,
+                error_message=str(e),
+            )

    async def unload_lora_adapter(
        self,
@@ -1108,37 +1144,41 @@ class TokenizerManager:
        _: Optional[fastapi.Request] = None,
    ) -> UnloadLoRAAdapterReqOutput:
        self.auto_create_handle_loop()
-        if not self.server_args.enable_lora:
-            raise ValueError(
-                "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-            )

-        assert (
-            obj.lora_name is not None
-        ), "lora_name must be provided to unload LoRA adapter"
+        try:
+            if not self.server_args.enable_lora:
+                raise ValueError(
+                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
+                )

-        # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-        # with dp_size > 1.
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for dynamic lora loading"
-        logger.info(
-            "Start unload Lora adapter. Lora name=%s",
-            obj.lora_name,
-        )
+            assert (
+                obj.lora_name is not None
+            ), "lora_name must be provided to unload LoRA adapter"
+
+            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
+            # with dp_size > 1.
+            assert (
+                self.server_args.dp_size == 1
+            ), "dp_size must be 1 for dynamic lora loading"
+            logger.info(
+                "Start unload Lora adapter. Lora name=%s",
+                obj.lora_name,
+            )

-        async with self.lora_update_lock:
-            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-            # from being started.
-            lora_id = await self.lora_registry.unregister(obj.lora_name)
-            obj.lora_id = lora_id
+            async with self.lora_update_lock:
+                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+                # from being started.
+                lora_id = await self.lora_registry.unregister(obj.lora_name)
+                obj.lora_id = lora_id

-            # Initiate the actual unloading operation at the backend processes only after all
-            # ongoing requests using this LoRA adapter are finished.
-            await self.lora_registry.wait_for_unload(lora_id)
-            result = (await self.update_lora_adapter_communicator(obj))[0]
+                # Initiate the actual unloading operation at the backend processes only after all
+                # ongoing requests using this LoRA adapter are finished.
+                await self.lora_registry.wait_for_unload(lora_id)
+                result = (await self.update_lora_adapter_communicator(obj))[0]

-        return result
+                return result
+        except ValueError as e:
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

    async def get_weights_by_name(
        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -1767,6 +1807,8 @@ class TokenizerManager:
        asyncio.create_task(asyncio.to_thread(background_task))

    def _handle_abort_req(self, recv_obj):
+        if is_health_check_generate_req(recv_obj):
+            return
        state = self.rid_to_state[recv_obj.rid]
        state.finished = True
        if recv_obj.finished_reason:
@@ -1901,6 +1943,16 @@ class TokenizerManager:
        return scores


+class ServerStatus(Enum):
+    Up = "Up"
+    Starting = "Starting"
+    UnHealthy = "UnHealthy"
+    Crashed = "Crashed"
+
+    def is_healthy(self) -> bool:
+        return self == ServerStatus.Up
+
+
 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
    is_cross_node = server_args.dist_init_addr

sglang/srt/managers/utils.py

@@ -1,6 +1,7 @@
 import logging
+import multiprocessing as mp
 from http import HTTPStatus
-from typing import Optional
+from typing import Dict, List, Optional

 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req

@@ -38,3 +39,46 @@ def validate_input_length(
        return error_msg

    return None
+
+
+class DPBalanceMeta:
+    """
+    This class will be use in scheduler and dp controller
+    """
+
+    def __init__(self, num_workers: int):
+        self.num_workers = num_workers
+        self._manager = mp.Manager()
+        self.mutex = self._manager.Lock()
+
+        init_local_tokens = [0] * self.num_workers
+        init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
+
+        self.shared_state = self._manager.Namespace()
+        self.shared_state.local_tokens = self._manager.list(init_local_tokens)
+        self.shared_state.onfly_info = self._manager.list(init_onfly_info)
+
+    def destructor(self):
+        # we must destructor this class manually
+        self._manager.shutdown()
+
+    def get_shared_onfly(self) -> List[Dict[int, int]]:
+        return [dict(d) for d in self.shared_state.onfly_info]
+
+    def set_shared_onfly_info(self, data: List[Dict[int, int]]):
+        self.shared_state.onfly_info = data
+
+    def get_shared_local_tokens(self) -> List[int]:
+        return list(self.shared_state.local_tokens)
+
+    def set_shared_local_tokens(self, data: List[int]):
+        self.shared_state.local_tokens = data
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        del state["_manager"]
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self._manager = None
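Note: DPBalanceMeta is built on multiprocessing.Manager proxies so that one instance, created in the parent process (the data parallel controller), can be handed to every scheduler process; __getstate__/__setstate__ drop the Manager handle when the object is pickled into a child. A hedged lifecycle sketch, assuming sglang 0.4.10.post2 is installed; worker_loop is hypothetical and the real wiring lives in data_parallel_controller.py and run_scheduler_process:

import multiprocessing as mp

from sglang.srt.managers.utils import DPBalanceMeta

def worker_loop(rank: int, meta: DPBalanceMeta):
    # Every scheduler process reads and writes the shared state under the shared mutex.
    with meta.mutex:
        tokens = meta.get_shared_local_tokens()
        tokens[rank] += 1
        meta.set_shared_local_tokens(tokens)

if __name__ == "__main__":
    meta = DPBalanceMeta(num_workers=2)  # created once in the parent process
    procs = [mp.Process(target=worker_loop, args=(r, meta)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(meta.get_shared_local_tokens())  # [1, 1]
    meta.destructor()  # the Manager must be shut down explicitly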