sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/lang/interpreter.py +1 -1
  2. sglang/srt/configs/internvl.py +6 -0
  3. sglang/srt/configs/model_config.py +2 -1
  4. sglang/srt/disaggregation/mini_lb.py +2 -2
  5. sglang/srt/distributed/parallel_state.py +46 -41
  6. sglang/srt/entrypoints/engine.py +1 -1
  7. sglang/srt/entrypoints/http_server.py +5 -1
  8. sglang/srt/entrypoints/openai/protocol.py +3 -3
  9. sglang/srt/entrypoints/openai/serving_chat.py +3 -3
  10. sglang/srt/entrypoints/openai/serving_completions.py +3 -1
  11. sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
  12. sglang/srt/entrypoints/openai/serving_responses.py +1 -1
  13. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  14. sglang/srt/layers/attention/aiter_backend.py +93 -68
  15. sglang/srt/layers/communicator.py +45 -7
  16. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  17. sglang/srt/layers/moe/ep_moe/layer.py +2 -7
  18. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  21. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  24. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  25. sglang/srt/layers/moe/utils.py +0 -1
  26. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
  27. sglang/srt/layers/quantization/modelopt_quant.py +35 -2
  28. sglang/srt/layers/quantization/mxfp4.py +4 -1
  29. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  30. sglang/srt/layers/quantization/quark/utils.py +97 -0
  31. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  32. sglang/srt/layers/quantization/w4afp8.py +30 -25
  33. sglang/srt/layers/rocm_linear_utils.py +44 -0
  34. sglang/srt/layers/rotary_embedding.py +0 -18
  35. sglang/srt/managers/cache_controller.py +42 -39
  36. sglang/srt/managers/detokenizer_manager.py +0 -34
  37. sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
  38. sglang/srt/managers/schedule_policy.py +3 -2
  39. sglang/srt/managers/scheduler.py +7 -100
  40. sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
  41. sglang/srt/managers/template_manager.py +3 -3
  42. sglang/srt/managers/tokenizer_manager.py +1 -0
  43. sglang/srt/mem_cache/allocator.py +1 -1
  44. sglang/srt/mem_cache/hicache_storage.py +15 -10
  45. sglang/srt/mem_cache/hiradix_cache.py +16 -0
  46. sglang/srt/mem_cache/memory_pool_host.py +18 -11
  47. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  48. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
  49. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
  50. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  51. sglang/srt/metrics/collector.py +12 -4
  52. sglang/srt/metrics/utils.py +48 -0
  53. sglang/srt/model_executor/forward_batch_info.py +16 -17
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +245 -36
  56. sglang/srt/models/glm4_moe.py +10 -1
  57. sglang/srt/models/gpt_oss.py +5 -4
  58. sglang/srt/models/internvl.py +28 -0
  59. sglang/srt/models/longcat_flash.py +26 -15
  60. sglang/srt/models/longcat_flash_nextn.py +23 -15
  61. sglang/srt/models/minicpmv.py +165 -3
  62. sglang/srt/models/qwen2_moe.py +4 -1
  63. sglang/srt/models/qwen3.py +8 -2
  64. sglang/srt/models/qwen3_moe.py +39 -8
  65. sglang/srt/models/torch_native_llama.py +1 -1
  66. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  67. sglang/srt/server_args.py +79 -2
  68. sglang/srt/speculative/eagle_worker.py +158 -112
  69. sglang/srt/utils.py +12 -10
  70. sglang/test/few_shot_gsm8k.py +1 -0
  71. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  72. sglang/utils.py +1 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
  75. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
  76. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  77. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  78. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  79. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  80. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  81. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  82. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
  83. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0

sglang/srt/managers/cache_controller.py
@@ -324,6 +324,22 @@ class HiCacheController:
                 group_ranks, backend="gloo"
             )

+        # Select the get and set functions
+        self.page_get_func = self._generic_page_get
+        self.page_set_func = self._generic_page_set
+        self.batch_exists_func = self.storage_backend.batch_exists
+        self.is_3fs_zerocopy = (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        )
+        if self.storage_backend_type == "mooncake":
+            self.page_get_func = self._mooncake_page_get
+            self.page_set_func = self._mooncake_page_set
+        elif self.is_3fs_zerocopy:
+            self.page_get_func = self._3fs_zero_copy_page_get
+            self.page_set_func = self._3fs_zero_copy_page_set
+            self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
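
Note: the hunk above binds the backend-specific get/set handlers once in `__init__` instead of re-selecting them on every transfer (the old per-call selection is removed from `_page_transfer` and `_page_backup` in later hunks). A minimal sketch of that pattern; the class and method names below are illustrative, not sglang APIs:

```python
class _StorageClient:
    """Illustrative only: pick backend-specific handlers once, call them in hot paths."""

    def __init__(self, backend_type: str, layout: str = "page_first"):
        # Generic defaults.
        self.page_get_func = self._generic_page_get
        self.page_set_func = self._generic_page_set
        # Override per backend, mirroring the selection logic in the diff.
        if backend_type == "mooncake":
            self.page_get_func = self._mooncake_page_get
            self.page_set_func = self._mooncake_page_set
        elif backend_type == "hf3fs" and layout == "page_first":
            self.page_get_func = self._zero_copy_page_get
            self.page_set_func = self._zero_copy_page_set

    def _generic_page_get(self, key): return f"generic get {key}"
    def _generic_page_set(self, key): return f"generic set {key}"
    def _mooncake_page_get(self, key): return f"mooncake get {key}"
    def _mooncake_page_set(self, key): return f"mooncake set {key}"
    def _zero_copy_page_get(self, key): return f"zero-copy get {key}"
    def _zero_copy_page_set(self, key): return f"zero-copy set {key}"

    def transfer(self, keys):
        # The per-batch loop no longer re-runs the if/elif chain.
        return [self.page_get_func(k) for k in keys]


print(_StorageClient("mooncake").transfer(["p0", "p1"]))
```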
@@ -407,6 +423,7 @@ class HiCacheController:
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 is_mla_model=is_mla_backend,
+                is_page_first_layout=self.mem_pool_host.layout == "page_first",
                 model_name=model_name,
                 extra_config=extra_config,
             )
@@ -616,13 +633,19 @@ class HiCacheController:
         for chunk in chunks:
             self.host_mem_release_queue.put(chunk)

+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
     def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         page_data = self.storage_backend.batch_get(hashes, dsts)
         if page_data:
-            operation.increment(self.page_size * len(hashes))
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
         else:
             logger.warning(
                 f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
@@ -636,7 +659,7 @@ class HiCacheController:
         )
         get_result = self.storage_backend.batch_get(
             key_strs,
-            target_location=buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         if get_result != len(hash_values):
@@ -647,9 +670,9 @@ class HiCacheController:
            operation.increment(get_result * self.page_size)

     def _generic_page_get(self, operation, hash_values, host_indices):
-        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
-            hash_values
-        )
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
         page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
         if page_data is None:
             return
@@ -659,26 +682,16 @@ class HiCacheController:
                    f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                )
                break
-            if operation.increment(self.page_size):
-                self.mem_pool_host.set_from_flat_data_page(
-                    host_indices[i * self.page_size],
-                    page_data[i],
-                )
-            else:
-                break
+            # Must set the data before increasing the completed tokens.
+            # Otherwise this page may be read before being set.
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.page_size],
+                page_data[i],
+            )
+            if not operation.increment(self.page_size):
+                break  # Operation terminated by controller

     def _page_transfer(self, operation):
-        # Select the get function and batch size
-        if self.storage_backend_type == "mooncake":
-            get_func = self._mooncake_page_get
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            get_func = self._3fs_zero_copy_page_get
-        else:
-            get_func = self._generic_page_get
-
         # Transfer batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
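
Note: the `_generic_page_get` change above reorders the copy and the counter update so the page is written into the host pool before `operation.increment` makes it visible to readers. A toy sketch of that ordering invariant; every name here is illustrative:

```python
# Toy model: a reader trusts `completed_tokens` to tell it how much of the
# host pool is valid, so the data must land before the counter advances.
PAGE_SIZE = 64

host_pool = {}          # page_index -> page data
completed_tokens = 0    # progress visible to concurrent readers


def write_page(page_index, data):
    global completed_tokens
    host_pool[page_index] = data      # 1) publish the page data first
    completed_tokens += PAGE_SIZE     # 2) only then advance visible progress


def read_if_ready(page_index):
    # A reader only touches pages covered by completed_tokens.
    if completed_tokens >= (page_index + 1) * PAGE_SIZE:
        return host_pool[page_index]
    return None


write_page(0, "kv-page-0")
print(read_if_ready(0))  # "kv-page-0": never observed before it was written
```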
@@ -687,7 +700,7 @@ class HiCacheController:
             ]
             prev_completed_tokens = operation.completed_tokens
             # Get one batch token, and update the completed_tokens if succeed
-            get_func(operation, batch_hashes, batch_host_indices)
+            self.page_get_func(operation, batch_hashes, batch_host_indices)
             # Check termination
             if (
                 operation.completed_tokens
@@ -744,7 +757,7 @@ class HiCacheController:
                    batch_tokens[i : i + self.page_size], last_hash
                )
                batch_hashes.append(last_hash)
-            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hit_page_num = self.batch_exists_func(batch_hashes)
            hash_value.extend(batch_hashes[:hit_page_num])
            storage_query_count += hit_page_num * self.page_size
            if hit_page_num < len(batch_hashes):
@@ -830,30 +843,20 @@ class HiCacheController:
         )
         success = self.storage_backend.batch_set(
             key_strs,
-            target_location=buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         return success

     # zero copy
     def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         return self.storage_backend.batch_set(hashes, dsts)

     # Backup batch by batch
     def _page_backup(self, operation):
-        # Select the set function and batch size
-        if self.storage_backend_type == "mooncake":
-            backup_set_func = self._mooncake_page_set
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            backup_set_func = self._3fs_zero_copy_page_set
-        else:
-            backup_set_func = self._generic_page_set
         # Backup batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -862,7 +865,7 @@ class HiCacheController:
             ]
             # Set one batch token, and record if success.
             # todo: allow partial success
-            success = backup_set_func(batch_hashes, batch_host_indices)
+            success = self.page_set_func(batch_hashes, batch_host_indices)
             if not success:
                 logger.warning(
                     f"Write page to storage: {len(batch_hashes)} pages failed."

sglang/srt/managers/detokenizer_manager.py
@@ -39,7 +39,6 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
     freeze_gc,
-    get_worker_ids_from_req_rids,
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
@@ -120,39 +119,6 @@ class DetokenizerManager(MultiTokenizerMixin):
         if output is not None:
             self.send_to_tokenizer.send_pyobj(output)

-    def multi_tokenizer_manager_event_loop(self):
-        """The event loop that handles requests, for multi tokenizer manager mode only"""
-        self.create_sockets_mapping()
-        while True:
-            recv_obj = self.recv_from_scheduler.recv_pyobj()
-            output = self._request_dispatcher(recv_obj)
-            if output is None:
-                continue
-            # Extract worker_id from rid
-            if isinstance(recv_obj.rids, list):
-                worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
-            else:
-                raise RuntimeError(
-                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
-                )
-
-            # Send data using the corresponding socket
-            for i, worker_id in enumerate(worker_ids):
-                if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                    if self.register_tokenizer_ipc(recv_obj, worker_id):
-                        logger.info(
-                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
-                        )
-                    continue
-                else:
-                    if worker_id not in self.tokenizer_mapping:
-                        logger.error(
-                            f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
-                        )
-                        continue
-                    new_output = self._handle_output_by_index(output, i)
-                    self.tokenizer_mapping[worker_id].send_pyobj(new_output)
-
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
     ):

sglang/srt/managers/multi_tokenizer_mixin.py
@@ -23,6 +23,7 @@ import threading
 from multiprocessing import shared_memory
 from typing import Dict

+import setproctitle
 import zmq
 import zmq.asyncio

@@ -37,11 +38,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    get_worker_ids_from_req_rids,
-    get_zmq_socket,
-    kill_process_tree,
-)
+from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -344,6 +341,48 @@ class MultiTokenizerMixin:
             new_output = output
         return new_output

+    def get_worker_ids_from_req_rids(self, rids):
+        if isinstance(rids, list):
+            worker_ids = [int(rid.split("_")[0]) for rid in rids]
+        elif isinstance(rids, str):
+            worker_ids = [int(rids.split("_")[0])]
+        else:
+            worker_ids = []
+        return worker_ids
+
+    def multi_tokenizer_manager_event_loop(self):
+        """The event loop that handles requests, for multi tokenizer manager mode only"""
+        self.create_sockets_mapping()
+        while True:
+            recv_obj = self.recv_from_scheduler.recv_pyobj()
+            output = self._request_dispatcher(recv_obj)
+            if output is None:
+                continue
+            # Extract worker_id from rid
+            if isinstance(recv_obj.rids, list):
+                worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
+            else:
+                raise RuntimeError(
+                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
+                )
+
+            # Send data using the corresponding socket
+            for i, worker_id in enumerate(worker_ids):
+                if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                    if self.register_tokenizer_ipc(recv_obj, worker_id):
+                        logger.info(
+                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
+                        )
+                    continue
+                else:
+                    if worker_id not in self.tokenizer_mapping:
+                        logger.error(
+                            f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
+                        )
+                        continue
+                    new_output = self._handle_output_by_index(output, i)
+                    self.tokenizer_mapping[worker_id].send_pyobj(new_output)
+
     def clear_tokenizer_mapping(self):
         if hasattr(self, "tokenizer_mapping"):
             for socket in self.tokenizer_mapping.values():
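
Note: the relocated `get_worker_ids_from_req_rids` helper assumes request ids carry the originating tokenizer worker id as an integer prefix before the first underscore. A standalone sketch of just that parsing, with made-up rids (the real rid format is produced by the multi-tokenizer router):

```python
def get_worker_ids_from_req_rids(rids):
    # Mirrors the mixin method above: the worker id is the integer prefix before "_".
    if isinstance(rids, list):
        return [int(rid.split("_")[0]) for rid in rids]
    if isinstance(rids, str):
        return [int(rids.split("_")[0])]
    return []


print(get_worker_ids_from_req_rids(["0_abc123", "2_def456"]))  # [0, 2]
print(get_worker_ids_from_req_rids("1_xyz789"))                # [1]
```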
@@ -406,7 +445,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
             worker_ids = [recv_obj.worker_id]
             recv_obj = recv_obj.obj
         else:
-            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
+            worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)

         if len(worker_ids) == 0:
             logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
@@ -438,6 +477,9 @@ class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        setproctitle.setproctitle(
+            f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}"
+        )
         # prevent init prefill bootstrapserver again
         disaggregation_mode = server_args.disaggregation_mode
         server_args.disaggregation_mode = "null"

sglang/srt/managers/schedule_policy.py
@@ -380,8 +380,9 @@ class PrefillAdder:
         self.log_input_tokens += extend_input_len

     def add_chunked_req(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens))
+        truncated = req.extend_input_len > _rem_tokens
+        req.extend_input_len = min(req.extend_input_len, _rem_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
         self._update_prefill_budget(
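
Note: `add_chunked_req` now clamps the chunk by both the remaining chunk budget and the remaining total token budget, rather than by the chunk budget alone. A worked example with made-up numbers:

```python
# Illustrative numbers only.
rem_chunk_tokens = 8192      # chunked-prefill budget left in this batch
rem_total_tokens = 3000.0    # overall token budget left (the tighter limit here)
extend_input_len = 5000      # tokens this request still wants to prefill

# Old behavior: only the chunk budget capped the request.
old_len = min(extend_input_len, rem_chunk_tokens)           # 5000, exceeds rem_total_tokens

# New behavior: cap by both budgets.
_rem_tokens = min(rem_chunk_tokens, int(rem_total_tokens))  # 3000
truncated = extend_input_len > _rem_tokens                  # True
new_len = min(extend_input_len, _rem_tokens)                # 3000

print(old_len, new_len, truncated)
```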

sglang/srt/managers/scheduler.py
@@ -141,7 +141,7 @@ from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -500,6 +500,7 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
+        self.init_dp_balance(dp_balance_meta)

         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -545,15 +546,6 @@ class Scheduler(
             ]
         )

-        self.balance_meta = dp_balance_meta
-        if (
-            server_args.enable_dp_attention
-            and server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-        self.recv_dp_balance_id_this_term = []
-
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
@@ -1126,11 +1118,7 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+        self.maybe_update_dp_balance_data(recv_req)

         # Create a new request
         if (
@@ -1568,11 +1556,7 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
+            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1897,86 +1881,6 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(
-                recv_list
-            )  # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
-                recv_list, dtype=torch.int32
-            )
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(
-                recv_tensor, gathered_list, group=self.tp_cpu_group
-            )
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(
-                        tensor[2 : list_length + 2].tolist()
-                    )
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(
-                    onfly_list
-                ), "num_worker not equal"
-                # 1.Check if the rid received by each worker this round is present in onfly.
-                # If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0:  # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
@@ -2403,6 +2307,9 @@ class Scheduler(
             # This only works for requests that have not started anything.
             # We still need to send something back to TokenizerManager to clean up the state.
             req = self.waiting_queue.pop(i)
+            if self.enable_hicache_storage:
+                # to release prefetch events associated with the request
+                self.tree_cache.release_aborted_request(req.rid)
             self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
             # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
             if self.disaggregation_mode == DisaggregationMode.DECODE:

sglang/srt/managers/scheduler_metrics_mixin.py
@@ -1,15 +1,24 @@
+from __future__ import annotations
+
 import logging
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+import torch

 from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var

+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)

 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@@ -28,7 +37,9 @@ class KvMetrics:


 class SchedulerMetricsMixin:
-    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
+    def init_metrics(
+        self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
+    ):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -50,14 +61,24 @@ class SchedulerMetricsMixin:
             labels["dp_rank"] = dp_rank
         self.metrics_collector = SchedulerMetricsCollector(labels=labels)

-    def init_kv_events(self, kv_events_config: Optional[str]):
+    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
+        self.balance_meta = dp_balance_meta
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+        self.recv_dp_balance_id_this_term = []
+
+    def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )

     def log_prefill_stats(
-        self,
+        self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
@@ -138,7 +159,7 @@ class SchedulerMetricsMixin:
         self._publish_kv_events()

     def log_decode_stats(
-        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+        self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
     ):
         batch = running_batch or self.running_batch

@@ -220,7 +241,7 @@ class SchedulerMetricsMixin:
         self._emit_kv_metrics()
         self._publish_kv_events()

-    def _emit_kv_metrics(self):
+    def _emit_kv_metrics(self: Scheduler):
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
@@ -236,9 +257,94 @@ class SchedulerMetricsMixin:
         if not self.send_metrics_from_scheduler.closed:
             self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

-    def _publish_kv_events(self):
+    def _publish_kv_events(self: Scheduler):
         if self.enable_kv_cache_events:
             events = self.tree_cache.take_events()
             if events:
                 batch = KVEventBatch(ts=time.time(), events=events)
                 self.kv_event_publisher.publish(batch)
+
+    def maybe_update_dp_balance_data(
+        self: Scheduler, recv_req: TokenizedGenerateReqInput
+    ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
+    def maybe_handle_dp_balance_data(self: Scheduler):
+        if (
+            self.server_args.load_balance_method == "minimum_tokens"
+            and self.forward_ct % 40 == 0
+        ):
+            holding_tokens = self.get_load()
+
+            new_recv_dp_balance_id_list, holding_token_list = (
+                self.gather_dp_balance_info(holding_tokens)
+            )
+
+            self.recv_dp_balance_id_this_term.clear()
+            if self.tp_rank == 0:  # only first worker write info
+                self.write_shared_dp_balance_info(
+                    new_recv_dp_balance_id_list, holding_token_list
+                )
+
+    def gather_dp_balance_info(
+        self: Scheduler, holding_tokens_list
+    ) -> Union[None, List[List[int]]]:
+        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+        recv_list = self.recv_dp_balance_id_this_term
+        assert len(recv_list) <= 511, (
+            "The number of requests received this round is too large. "
+            "Please increase gather_tensor_size and onfly_info_size."
+        )
+        # The maximum size of the tensor used for gathering data from all workers.
+        gather_tensor_size = 512
+
+        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+        recv_tensor[0] = holding_tokens_list
+        recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
+        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
+
+        if self.tp_rank == 0:
+            gathered_list = [
+                torch.zeros(gather_tensor_size, dtype=torch.int32)
+                for _ in range(self.balance_meta.num_workers)
+            ]
+        else:
+            gathered_list = None
+
+        torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
+
+        gathered_id_list_per_worker = None
+        if self.tp_rank == 0:
+            gathered_id_list_per_worker = []
+            holding_tokens_list = []
+            for tensor in gathered_list:
+                holding_tokens_list.append(tensor[0].item())
+                list_length = tensor[1].item()
+                gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
+
+        return gathered_id_list_per_worker, holding_tokens_list
+
+    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
+        meta = self.balance_meta
+
+        with meta.mutex:
+            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
+            # 1.Check if the rid received by each worker this round is present in onfly.
+            # If it is, remove the corresponding onfly item.
+            worker_id = 0
+            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                for new_recv_rid in new_recv_rids:
+                    assert (
+                        new_recv_rid in on_fly_reqs
+                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                    del on_fly_reqs[new_recv_rid]
+                worker_id += 1
+            # 2. Atomically write local_tokens and onfly into shm under the mutex
+            meta.set_shared_onfly_info(onfly_list)
+            meta.set_shared_local_tokens(local_tokens)
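
Note: `gather_dp_balance_info` above packs each worker's payload into a fixed-size int32 tensor before the gather: slot 0 holds the worker's holding-token count, slot 1 the number of request ids, and slots 2..n+1 the ids themselves. A small pack/unpack sketch with toy values; only the 512-slot layout comes from the diff, the helper names are illustrative:

```python
import torch

GATHER_TENSOR_SIZE = 512  # fixed capacity used in the diff


def pack(holding_tokens: int, recv_ids: list) -> torch.Tensor:
    # Layout: | holding_tokens | len(recv_ids) | recv_ids... | zero padding |
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    t[2 : len(recv_ids) + 2] = torch.tensor(recv_ids, dtype=torch.int32)
    return t


def unpack(t: torch.Tensor):
    n = int(t[1].item())
    return int(t[0].item()), t[2 : n + 2].tolist()


packed = pack(holding_tokens=1234, recv_ids=[7, 8, 9])
print(unpack(packed))  # (1234, [7, 8, 9])
```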

sglang/srt/managers/template_manager.py
@@ -24,20 +24,20 @@ import os
 import re
 from typing import Optional

-from sglang.srt.code_completion_parser import (
+from sglang.srt.parser.code_completion_parser import (
     CompletionTemplate,
     FimPosition,
     completion_template_exists,
     register_completion_template,
 )
-from sglang.srt.conversation import (
+from sglang.srt.parser.conversation import (
     Conversation,
     SeparatorStyle,
     chat_template_exists,
     get_conv_template_by_model_path,
     register_conv_template,
 )
-from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
+from sglang.srt.parser.jinja_template_utils import detect_jinja_template_content_format

 logger = logging.getLogger(__name__)


sglang/srt/managers/tokenizer_manager.py
@@ -329,6 +329,7 @@ class TokenizerManager:
         # Metrics
         if self.enable_metrics:
             self.metrics_collector = TokenizerMetricsCollector(
+                server_args=server_args,
                 labels={
                     "model_name": self.server_args.served_model_name,
                     # TODO: Add lora name/path in the future,