sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/model_config.py +2 -1
- sglang/srt/disaggregation/mini_lb.py +2 -2
- sglang/srt/distributed/parallel_state.py +46 -41
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +5 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +3 -3
- sglang/srt/entrypoints/openai/serving_completions.py +3 -1
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
- sglang/srt/entrypoints/openai/serving_responses.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/layer.py +2 -7
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/utils.py +0 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
- sglang/srt/layers/quantization/modelopt_quant.py +35 -2
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/managers/cache_controller.py +42 -39
- sglang/srt/managers/detokenizer_manager.py +0 -34
- sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
- sglang/srt/managers/schedule_policy.py +3 -2
- sglang/srt/managers/scheduler.py +7 -100
- sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +15 -10
- sglang/srt/mem_cache/hiradix_cache.py +16 -0
- sglang/srt/mem_cache/memory_pool_host.py +18 -11
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/metrics/collector.py +12 -4
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/forward_batch_info.py +16 -17
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +245 -36
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/gpt_oss.py +5 -4
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/longcat_flash.py +26 -15
- sglang/srt/models/longcat_flash_nextn.py +23 -15
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/qwen2_moe.py +4 -1
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/server_args.py +79 -2
- sglang/srt/speculative/eagle_worker.py +158 -112
- sglang/srt/utils.py +12 -10
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -324,6 +324,22 @@ class HiCacheController:
             group_ranks, backend="gloo"
         )

+        # Select the get and set functions
+        self.page_get_func = self._generic_page_get
+        self.page_set_func = self._generic_page_set
+        self.batch_exists_func = self.storage_backend.batch_exists
+        self.is_3fs_zerocopy = (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        )
+        if self.storage_backend_type == "mooncake":
+            self.page_get_func = self._mooncake_page_get
+            self.page_set_func = self._mooncake_page_set
+        elif self.is_3fs_zerocopy:
+            self.page_get_func = self._3fs_zero_copy_page_get
+            self.page_set_func = self._3fs_zero_copy_page_set
+            self.batch_exists_func = self._3fs_zero_copy_batch_exists
+
         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
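The hunk above replaces per-call backend branching in the transfer and backup paths with functions bound once at construction: page_get_func, page_set_func, and batch_exists_func default to the generic/storage-backend implementations and are overridden for the mooncake and hf3fs page-first cases. A minimal, self-contained sketch of the same constructor-time dispatch pattern (the class and stub methods below are hypothetical stand-ins, not the real HiCacheController):

class TieredCacheControllerSketch:
    """Toy model of the dispatch shown in the diff above."""

    def __init__(self, storage_backend_type: str, host_layout: str):
        # Defaults: generic paths.
        self.page_get_func = self._generic_page_get
        self.page_set_func = self._generic_page_set
        # Specialize once; hot loops later call self.page_get_func directly
        # instead of re-checking the backend type on every batch.
        if storage_backend_type == "mooncake":
            self.page_get_func = self._mooncake_page_get
            self.page_set_func = self._mooncake_page_set
        elif storage_backend_type == "hf3fs" and host_layout == "page_first":
            self.page_get_func = self._3fs_zero_copy_page_get
            self.page_set_func = self._3fs_zero_copy_page_set

    # Stub implementations, for illustration only.
    def _generic_page_get(self, hashes):
        return f"generic get of {len(hashes)} pages"

    def _generic_page_set(self, hashes):
        return f"generic set of {len(hashes)} pages"

    def _mooncake_page_get(self, hashes):
        return f"mooncake get of {len(hashes)} pages"

    def _mooncake_page_set(self, hashes):
        return f"mooncake set of {len(hashes)} pages"

    def _3fs_zero_copy_page_get(self, hashes):
        return f"hf3fs zero-copy get of {len(hashes)} pages"

    def _3fs_zero_copy_page_set(self, hashes):
        return f"hf3fs zero-copy set of {len(hashes)} pages"


ctrl = TieredCacheControllerSketch("hf3fs", "page_first")
print(ctrl.page_get_func(["h0", "h1"]))  # hf3fs zero-copy get of 2 pages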
@@ -407,6 +423,7 @@ class HiCacheController:
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             is_mla_model=is_mla_backend,
+            is_page_first_layout=self.mem_pool_host.layout == "page_first",
             model_name=model_name,
             extra_config=extra_config,
         )
@@ -616,13 +633,19 @@ class HiCacheController:
         for chunk in chunks:
             self.host_mem_release_queue.put(chunk)

+    def _3fs_zero_copy_batch_exists(self, batch_hashes):
+        _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
+        hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
+        return hit_page_num
+
     def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         page_data = self.storage_backend.batch_get(hashes, dsts)
         if page_data:
-            operation.increment(self.page_size * len(hashes))
+            inc = self.page_size * len(hashes) // factor
+            operation.increment(inc)
         else:
             logger.warning(
                 f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
@@ -636,7 +659,7 @@ class HiCacheController:
         )
         get_result = self.storage_backend.batch_get(
             key_strs,
-            buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         if get_result != len(hash_values):
@@ -647,9 +670,9 @@ class HiCacheController:
         operation.increment(get_result * self.page_size)

     def _generic_page_get(self, operation, hash_values, host_indices):
-        dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
-            hash_values
-        )
+        dummy_page_dst = [
+            self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
+        ]
         page_data = self.storage_backend.batch_get(hash_values, dummy_page_dst)
         if page_data is None:
             return
@@ -659,26 +682,16 @@ class HiCacheController:
                     f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
                 )
                 break
-
-
-
-
-
-
-
+            # Must set the data before increasing the completed tokens.
+            # Otherwise this page may be read before being set.
+            self.mem_pool_host.set_from_flat_data_page(
+                host_indices[i * self.page_size],
+                page_data[i],
+            )
+            if not operation.increment(self.page_size):
+                break  # Operation terminated by controller

     def _page_transfer(self, operation):
-        # Select the get function and batch size
-        if self.storage_backend_type == "mooncake":
-            get_func = self._mooncake_page_get
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            get_func = self._3fs_zero_copy_page_get
-        else:
-            get_func = self._generic_page_get
-
         # Transfer batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
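The new comment in _generic_page_get captures an ordering rule: the host page must be populated before operation.increment advances the completed-token count, because a consumer that polls that count treats every token below it as readable. A toy, lock-based illustration of the publish-after-write pattern (not sglang code; just the general idea):

import threading

pages = [None] * 4        # destination slots
completed = 0             # watermark: slots [0, completed) are readable
lock = threading.Lock()

def write_page(i, data):
    global completed
    pages[i] = data            # 1) set the data first
    with lock:
        completed = i + 1      # 2) only then advance the visible watermark

def read_ready_pages():
    with lock:
        n = completed
    return pages[:n]           # never observes a slot that was not written

write_page(0, "kv-page-0")
print(read_ready_pages())      # ['kv-page-0']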
@@ -687,7 +700,7 @@ class HiCacheController:
             ]
             prev_completed_tokens = operation.completed_tokens
             # Get one batch token, and update the completed_tokens if succeed
-            get_func(operation, batch_hashes, batch_host_indices)
+            self.page_get_func(operation, batch_hashes, batch_host_indices)
             # Check termination
             if (
                 operation.completed_tokens
@@ -744,7 +757,7 @@ class HiCacheController:
                     batch_tokens[i : i + self.page_size], last_hash
                 )
                 batch_hashes.append(last_hash)
-            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hit_page_num = self.batch_exists_func(batch_hashes)
             hash_value.extend(batch_hashes[:hit_page_num])
             storage_query_count += hit_page_num * self.page_size
             if hit_page_num < len(batch_hashes):
@@ -830,30 +843,20 @@ class HiCacheController:
         )
         success = self.storage_backend.batch_set(
             key_strs,
-            buffer_ptrs,
+            target_locations=buffer_ptrs,
             target_sizes=buffer_sizes,
         )
         return success

     # zero copy
     def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
-        hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
+        hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
         )
         return self.storage_backend.batch_set(hashes, dsts)

     # Backup batch by batch
     def _page_backup(self, operation):
-        # Select the set function and batch size
-        if self.storage_backend_type == "mooncake":
-            backup_set_func = self._mooncake_page_set
-        elif (
-            self.storage_backend_type == "hf3fs"
-            and self.mem_pool_host.layout == "page_first"
-        ):
-            backup_set_func = self._3fs_zero_copy_page_set
-        else:
-            backup_set_func = self._generic_page_set
         # Backup batch by batch
         for i in range(0, len(operation.hash_value), self.storage_batch_size):
             batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
@@ -862,7 +865,7 @@ class HiCacheController:
             ]
             # Set one batch token, and record if success.
             # todo: allow partial success
-            success = backup_set_func(batch_hashes, batch_host_indices)
+            success = self.page_set_func(batch_hashes, batch_host_indices)
             if not success:
                 logger.warning(
                     f"Write page to storage: {len(batch_hashes)} pages failed."
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
     freeze_gc,
-    get_worker_ids_from_req_rids,
     get_zmq_socket,
     kill_itself_when_parent_died,
 )
@@ -120,39 +119,6 @@ class DetokenizerManager(MultiTokenizerMixin):
         if output is not None:
             self.send_to_tokenizer.send_pyobj(output)

-    def multi_tokenizer_manager_event_loop(self):
-        """The event loop that handles requests, for multi tokenizer manager mode only"""
-        self.create_sockets_mapping()
-        while True:
-            recv_obj = self.recv_from_scheduler.recv_pyobj()
-            output = self._request_dispatcher(recv_obj)
-            if output is None:
-                continue
-            # Extract worker_id from rid
-            if isinstance(recv_obj.rids, list):
-                worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
-            else:
-                raise RuntimeError(
-                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
-                )
-
-            # Send data using the corresponding socket
-            for i, worker_id in enumerate(worker_ids):
-                if isinstance(recv_obj, MultiTokenizerRegisterReq):
-                    if self.register_tokenizer_ipc(recv_obj, worker_id):
-                        logger.info(
-                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
-                        )
-                    continue
-                else:
-                    if worker_id not in self.tokenizer_mapping:
-                        logger.error(
-                            f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
-                        )
-                        continue
-                new_output = self._handle_output_by_index(output, i)
-                self.tokenizer_mapping[worker_id].send_pyobj(new_output)
-
     def trim_matched_stop(
         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
     ):
sglang/srt/managers/multi_tokenizer_mixin.py
CHANGED
@@ -23,6 +23,7 @@ import threading
 from multiprocessing import shared_memory
 from typing import Dict

+import setproctitle
 import zmq
 import zmq.asyncio

@@ -37,11 +38,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    get_worker_ids_from_req_rids,
-    get_zmq_socket,
-    kill_process_tree,
-)
+from sglang.srt.utils import get_zmq_socket, kill_process_tree
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -344,6 +341,48 @@ class MultiTokenizerMixin:
             new_output = output
         return new_output

+    def get_worker_ids_from_req_rids(self, rids):
+        if isinstance(rids, list):
+            worker_ids = [int(rid.split("_")[0]) for rid in rids]
+        elif isinstance(rids, str):
+            worker_ids = [int(rids.split("_")[0])]
+        else:
+            worker_ids = []
+        return worker_ids
+
+    def multi_tokenizer_manager_event_loop(self):
+        """The event loop that handles requests, for multi tokenizer manager mode only"""
+        self.create_sockets_mapping()
+        while True:
+            recv_obj = self.recv_from_scheduler.recv_pyobj()
+            output = self._request_dispatcher(recv_obj)
+            if output is None:
+                continue
+            # Extract worker_id from rid
+            if isinstance(recv_obj.rids, list):
+                worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
+            else:
+                raise RuntimeError(
+                    f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
+                )
+
+            # Send data using the corresponding socket
+            for i, worker_id in enumerate(worker_ids):
+                if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                    if self.register_tokenizer_ipc(recv_obj, worker_id):
+                        logger.info(
+                            f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
+                        )
+                    continue
+                else:
+                    if worker_id not in self.tokenizer_mapping:
+                        logger.error(
+                            f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
+                        )
+                        continue
+                new_output = self._handle_output_by_index(output, i)
+                self.tokenizer_mapping[worker_id].send_pyobj(new_output)
+
     def clear_tokenizer_mapping(self):
         if hasattr(self, "tokenizer_mapping"):
             for socket in self.tokenizer_mapping.values():
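The new get_worker_ids_from_req_rids helper assumes each request id is prefixed with the owning tokenizer worker's numeric id and an underscore, so routing only needs the text before the first "_". A small usage sketch of that parsing rule (the rid values are made up for illustration):

def get_worker_ids_from_req_rids(rids):
    # Same parsing logic as the mixin method above, shown standalone.
    if isinstance(rids, list):
        worker_ids = [int(rid.split("_")[0]) for rid in rids]
    elif isinstance(rids, str):
        worker_ids = [int(rids.split("_")[0])]
    else:
        worker_ids = []
    return worker_ids

print(get_worker_ids_from_req_rids(["0_9f2c", "3_a771"]))  # [0, 3]
print(get_worker_ids_from_req_rids("7_b0de"))              # [7]
print(get_worker_ids_from_req_rids(None))                  # []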
@@ -406,7 +445,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
             worker_ids = [recv_obj.worker_id]
             recv_obj = recv_obj.obj
         else:
-            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
+            worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)

         if len(worker_ids) == 0:
             logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
@@ -438,6 +477,9 @@ class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
         server_args: ServerArgs,
         port_args: PortArgs,
     ):
+        setproctitle.setproctitle(
+            f"sglang::http_server/multi_tokenizer_manager:{os.getpid()}"
+        )
         # prevent init prefill bootstrapserver again
         disaggregation_mode = server_args.disaggregation_mode
         server_args.disaggregation_mode = "null"
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -380,8 +380,9 @@ class PrefillAdder:
         self.log_input_tokens += extend_input_len

     def add_chunked_req(self, req: Req):
-
-
+        _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens))
+        truncated = req.extend_input_len > _rem_tokens
+        req.extend_input_len = min(req.extend_input_len, _rem_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
         self.can_run_list.append(req)
         self._update_prefill_budget(
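The reworked add_chunked_req caps a chunk by both the remaining per-chunk budget and the remaining total token budget before truncating the request. A worked example of that arithmetic (the numbers are illustrative, not taken from the diff; int() is applied because the running total budget may be fractional):

# Illustrative budgets.
rem_chunk_tokens = 2048      # per-chunk budget left
rem_total_tokens = 1500.7    # overall budget left
extend_input_len = 4000      # tokens this request still wants to prefill

_rem_tokens = min(rem_chunk_tokens, int(rem_total_tokens))   # 1500
truncated = extend_input_len > _rem_tokens                   # True
extend_input_len = min(extend_input_len, _rem_tokens)        # 1500

print(_rem_tokens, truncated, extend_input_len)              # 1500 True 1500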
sglang/srt/managers/scheduler.py
CHANGED
@@ -141,7 +141,7 @@ from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -500,6 +500,7 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
+        self.init_dp_balance(dp_balance_meta)

         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -545,15 +546,6 @@ class Scheduler(
             ]
         )

-        self.balance_meta = dp_balance_meta
-        if (
-            server_args.enable_dp_attention
-            and server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-            self.recv_dp_balance_id_this_term = []
-
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
@@ -1126,11 +1118,7 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+        self.maybe_update_dp_balance_data(recv_req)

         # Create a new request
         if (
@@ -1568,11 +1556,7 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
+            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1897,86 +1881,6 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(
-                recv_list
-            )  # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
-                recv_list, dtype=torch.int32
-            )
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(
-                recv_tensor, gathered_list, group=self.tp_cpu_group
-            )
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(
-                        tensor[2 : list_length + 2].tolist()
-                    )
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(
-                    onfly_list
-                ), "num_worker not equal"
-                # 1.Check if the rid received by each worker this round is present in onfly.
-                # If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0:  # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
@@ -2403,6 +2307,9 @@ class Scheduler(
             # This only works for requests that have not started anything.
             # We still need to send something back to TokenizerManager to clean up the state.
             req = self.waiting_queue.pop(i)
+            if self.enable_hicache_storage:
+                # to release prefetch events associated with the request
+                self.tree_cache.release_aborted_request(req.rid)
             self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
             # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
             if self.disaggregation_mode == DisaggregationMode.DECODE:
sglang/srt/managers/scheduler_metrics_mixin.py
CHANGED
@@ -1,15 +1,24 @@
+from __future__ import annotations
+
 import logging
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+import torch

 from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var

+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)

 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@@ -28,7 +37,9 @@ class KvMetrics:


 class SchedulerMetricsMixin:
-    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
+    def init_metrics(
+        self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
+    ):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -50,14 +61,24 @@ class SchedulerMetricsMixin:
             labels["dp_rank"] = dp_rank
         self.metrics_collector = SchedulerMetricsCollector(labels=labels)

-    def init_kv_events(self, kv_events_config: Optional[str]):
+    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
+        self.balance_meta = dp_balance_meta
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+            self.recv_dp_balance_id_this_term = []
+
+    def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )

     def log_prefill_stats(
-        self,
+        self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
@@ -138,7 +159,7 @@ class SchedulerMetricsMixin:
         self._publish_kv_events()

     def log_decode_stats(
-        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+        self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
     ):
         batch = running_batch or self.running_batch

@@ -220,7 +241,7 @@ class SchedulerMetricsMixin:
         self._emit_kv_metrics()
         self._publish_kv_events()

-    def _emit_kv_metrics(self):
+    def _emit_kv_metrics(self: Scheduler):
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
@@ -236,9 +257,94 @@ class SchedulerMetricsMixin:
         if not self.send_metrics_from_scheduler.closed:
             self.send_metrics_from_scheduler.send_pyobj(kv_metrics)

-    def _publish_kv_events(self):
+    def _publish_kv_events(self: Scheduler):
         if self.enable_kv_cache_events:
             events = self.tree_cache.take_events()
             if events:
                 batch = KVEventBatch(ts=time.time(), events=events)
                 self.kv_event_publisher.publish(batch)
+
+    def maybe_update_dp_balance_data(
+        self: Scheduler, recv_req: TokenizedGenerateReqInput
+    ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
+    def maybe_handle_dp_balance_data(self: Scheduler):
+        if (
+            self.server_args.load_balance_method == "minimum_tokens"
+            and self.forward_ct % 40 == 0
+        ):
+            holding_tokens = self.get_load()
+
+            new_recv_dp_balance_id_list, holding_token_list = (
+                self.gather_dp_balance_info(holding_tokens)
+            )
+
+            self.recv_dp_balance_id_this_term.clear()
+            if self.tp_rank == 0:  # only first worker write info
+                self.write_shared_dp_balance_info(
+                    new_recv_dp_balance_id_list, holding_token_list
+                )
+
+    def gather_dp_balance_info(
+        self: Scheduler, holding_tokens_list
+    ) -> Union[None, List[List[int]]]:
+        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+        recv_list = self.recv_dp_balance_id_this_term
+        assert len(recv_list) <= 511, (
+            "The number of requests received this round is too large. "
+            "Please increase gather_tensor_size and onfly_info_size."
+        )
+        # The maximum size of the tensor used for gathering data from all workers.
+        gather_tensor_size = 512
+
+        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+        recv_tensor[0] = holding_tokens_list
+        recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
+        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
+
+        if self.tp_rank == 0:
+            gathered_list = [
+                torch.zeros(gather_tensor_size, dtype=torch.int32)
+                for _ in range(self.balance_meta.num_workers)
+            ]
+        else:
+            gathered_list = None
+
+        torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
+
+        gathered_id_list_per_worker = None
+        if self.tp_rank == 0:
+            gathered_id_list_per_worker = []
+            holding_tokens_list = []
+            for tensor in gathered_list:
+                holding_tokens_list.append(tensor[0].item())
+                list_length = tensor[1].item()
+                gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
+
+        return gathered_id_list_per_worker, holding_tokens_list
+
+    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
+        meta = self.balance_meta
+
+        with meta.mutex:
+            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
+            # 1.Check if the rid received by each worker this round is present in onfly.
+            # If it is, remove the corresponding onfly item.
+            worker_id = 0
+            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                for new_recv_rid in new_recv_rids:
+                    assert (
+                        new_recv_rid in on_fly_reqs
+                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                    del on_fly_reqs[new_recv_rid]
+                worker_id += 1
+            # 2. Atomically write local_tokens and onfly into shm under the mutex
+            meta.set_shared_onfly_info(onfly_list)
+            meta.set_shared_local_tokens(local_tokens)
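gather_dp_balance_info packs each worker's report into one fixed-size int32 tensor: slot 0 is the worker's held token count, slot 1 the number of request ids, and slots 2 onward the ids themselves, so a single torch.distributed.gather can collect variable-length lists. A standalone sketch of just the pack/unpack step (no process group involved; the helper names are made up):

import torch

GATHER_TENSOR_SIZE = 512  # same fixed capacity used by the mixin above

def pack_report(holding_tokens: int, recv_ids: list) -> torch.Tensor:
    # Layout: | holding_tokens | len(recv_ids) | recv_ids ... | zero padding |
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    t[2 : len(recv_ids) + 2] = torch.tensor(recv_ids, dtype=torch.int32)
    return t

def unpack_report(t: torch.Tensor):
    n = int(t[1].item())
    return int(t[0].item()), t[2 : n + 2].tolist()

packed = pack_report(1234, [7, 42, 99])
print(unpack_report(packed))  # (1234, [7, 42, 99])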
sglang/srt/managers/template_manager.py
CHANGED
@@ -24,20 +24,20 @@ import os
 import re
 from typing import Optional

-from sglang.srt.code_completion_parser import (
+from sglang.srt.parser.code_completion_parser import (
     CompletionTemplate,
     FimPosition,
     completion_template_exists,
     register_completion_template,
 )
-from sglang.srt.conversation import (
+from sglang.srt.parser.conversation import (
     Conversation,
     SeparatorStyle,
     chat_template_exists,
     get_conv_template_by_model_path,
     register_conv_template,
 )
-from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
+from sglang.srt.parser.jinja_template_utils import detect_jinja_template_content_format

 logger = logging.getLogger(__name__)

sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -329,6 +329,7 @@ class TokenizerManager:
         # Metrics
         if self.enable_metrics:
             self.metrics_collector = TokenizerMetricsCollector(
+                server_args=server_args,
                 labels={
                     "model_name": self.server_args.served_model_name,
                     # TODO: Add lora name/path in the future,
|