sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +29 -68
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +57 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +43 -39
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +145 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -12
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/model_executor/cuda_graph_runner.py +42 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +28 -23
- sglang/srt/models/glm4_moe.py +85 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -4
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +115 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +4 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -25,12 +25,6 @@ if TYPE_CHECKING:
     from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
     from sglang.srt.mem_cache.memory_pool_host import HostKVCache

-from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
-from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
-    MooncakeStore,
-    get_hash_str_mooncake,
-)
-from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import HiCacheHF3FS

 logger = logging.getLogger(__name__)

@@ -237,40 +231,36 @@ class HiCacheController:
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
         self.page_size = page_size
-
-        if not io_backend:
-            IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
-            self.io_backend = (
-                "direct"
-                if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
-                else "kernel"
-            )
-        else:
-            self.io_backend = io_backend
+        self.io_backend = io_backend

         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
-
-
-            if self.tp_world_size > 1:
-                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
-                self.prefetch_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
-                self.backup_tp_group = torch.distributed.new_group(
-                    group_ranks, backend="gloo"
-                )
+            self.storage_backend_type = storage_backend
+            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str

             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.get_hash_str = get_hash_str
+            elif storage_backend == "nixl":
+                from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+
+                self.storage_backend = HiCacheNixl()
+                self.get_hash_str = get_hash_str
             elif storage_backend == "mooncake":
+                from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+                    MooncakeStore,
+                    get_hash_str_mooncake,
+                )
+
                 self.storage_backend = MooncakeStore()
                 self.get_hash_str = get_hash_str_mooncake
                 self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
             elif storage_backend == "hf3fs":
                 from sglang.srt.distributed import get_tensor_model_parallel_rank
+                from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
+                    HiCacheHF3FS,
+                )

                 rank = get_tensor_model_parallel_rank()
                 bytes_per_page = (
@@ -288,6 +278,16 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.prefetch_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )
+                self.backup_tp_group = torch.distributed.new_group(
+                    group_ranks, backend="gloo"
+                )

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -439,11 +439,8 @@ class HiCacheController:
             host_indices, device_indices = self.move_indices(
                 operation.host_indices, operation.device_indices
             )
-            self.
-            self.
-                host_indices,
-                device_indices,
-                self.io_backend,
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
             )
             self.write_stream.synchronize()
             self.mem_pool_host.complete_io(operation.host_indices)
@@ -483,8 +480,8 @@ class HiCacheController:
                 batch_operation.host_indices, batch_operation.device_indices
             )
             for i in range(self.mem_pool_host.layer_num):
-                self.
-                self.
+                self.mem_pool_host.load_to_device_per_layer(
+                    self.mem_pool_device,
                     host_indices,
                     device_indices,
                     i,
@@ -545,7 +542,11 @@ class HiCacheController:
     def generic_page_transfer(self, operation, batch_size=8):
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
-
+            # todo: zero copy
+            dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
+                page_hashes
+            )
+            page_data = self.storage_backend.batch_get(page_hashes, dummy_page_dst)
             if page_data is None:
                 logger.warning(
                     f"Prefetch operation {operation.request_id} failed to retrieve page {page_hashes}."
@@ -573,6 +574,9 @@ class HiCacheController:
         self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
         operation.increment(len(operation.hash_value) * self.page_size)

+    def is_mooncake_backend(self):
+        return self.storage_backend_type == "mooncake"
+
     def prefetch_io_aux_func(self):
         """
         Auxiliary function conducting IO operations for prefetching.
@@ -580,7 +584,7 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
-                if
+                if self.is_mooncake_backend():
                     self.mooncake_page_transfer(operation)
                 else:
                     self.generic_page_transfer(operation)
@@ -615,14 +619,14 @@ class HiCacheController:
             )

             # todo, more unified interface
-            if not
+            if not self.is_mooncake_backend():
                 if not self.storage_backend.exists(last_hash):
                     break
             hash_value.append(last_hash)
             storage_hit_count += self.page_size
             remaining_tokens -= self.page_size

-        if
+        if self.is_mooncake_backend():
             # deferring to batch exists for mooncake store
             exist_result = self.storage_backend.exists(hash_value)
             storage_hit_count = (
@@ -679,7 +683,7 @@ class HiCacheController:
         for i in range(0, len(operation.hash_value), batch_size):
             page_hashes = operation.hash_value[i : i + batch_size]
             page_data = [
-                self.mem_pool_host.
+                self.mem_pool_host.get_flat_data_page(
                     operation.host_indices[j * self.page_size]
                 )
                 for j in range(i, i + len(page_hashes))
@@ -744,7 +748,7 @@ class HiCacheController:
             remaining_tokens -= self.page_size
             operation.hash_value = hash_value

-            if
+            if self.is_mooncake_backend():
                 self.mooncake_page_backup(operation)
             else:
                 self.generic_page_backup(operation)
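The hunks above move every optional storage-backend import out of the module header and into the branch that selects that backend, so importing cache_controller.py no longer requires mooncake, nixl, or hf3fs to be installed. A generic sketch of this lazy, per-branch import pattern (illustrative only; the helper name and module names below are stand-ins, not sglang's real modules):

# Generic sketch of lazy, per-branch imports for optional backends; names are hypothetical.
import importlib


def load_storage_backend_module(name: str):
    # Each optional dependency is imported only in the branch that needs it,
    # so a missing package fails only when that backend is actually requested.
    if name == "file":
        module_name = "json"  # stand-in for an always-available backend
    elif name == "mooncake":
        module_name = "mooncake_store"  # hypothetical optional package
    elif name == "nixl":
        module_name = "nixl"  # hypothetical optional package
    else:
        raise ValueError(f"unknown storage backend: {name}")
    return importlib.import_module(module_name)

In the diff itself the real lazy imports are sglang.srt.mem_cache.mooncake_store.mooncake_store, sglang.srt.mem_cache.nixl.hicache_nixl, and sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs.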
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -16,9 +16,13 @@
 import logging
 import multiprocessing as mp
 import signal
+import struct
+import sys
 import threading
 import time
 from enum import Enum, auto
+from multiprocessing import shared_memory
+from typing import Dict, List

 import psutil
 import setproctitle
@@ -32,6 +36,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):

     ROUND_ROBIN = auto()
     SHORTEST_QUEUE = auto()
+    MINIMUM_TOKENS = auto()

     @classmethod
     def from_str(cls, method: str):
@@ -58,7 +64,16 @@ class LoadBalanceMethod(Enum):
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""

-    def __init__(
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        dp_balance_meta: DPBalanceMeta,
+    ) -> None:
+        # for dp balance
+        self.global_balance_id = 0
+        self.balance_meta = dp_balance_meta
+
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
@@ -79,6 +94,7 @@ class DataParallelController:
         dispatch_lookup = {
             LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
             LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+            LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]

@@ -234,6 +250,7 @@ class DataParallelController:
                 pp_rank,
                 dp_rank,
                 writer,
+                self.balance_meta,
             ),
         )
         with memory_saver_adapter.configure_subprocess():
@@ -269,6 +286,33 @@ class DataParallelController:
     def shortest_queue_scheduler(self, input_requests):
         raise NotImplementedError()

+    def minimum_tokens_scheduler(self, req):
+        # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
+        # We use it to control the number of onfly tokens (requests dispatched to workers but not yet received).
+        def get_next_global_balance_id() -> int:
+            INT32_MAX = 2147483647
+            current_id = self.global_balance_id
+            self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
+            return current_id
+
+        req.dp_balance_id = get_next_global_balance_id()
+        with self.balance_meta.mutex:
+            # 1. local_tokens represents the tokens currently inferring on the worker,
+            # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
+            onfly_info = self.balance_meta.get_shared_onfly()
+            local_tokens = self.balance_meta.get_shared_local_tokens()
+            total_tokens = [
+                local_token + sum(onfly_dict.values())
+                for local_token, onfly_dict in zip(local_tokens, onfly_info)
+            ]
+            target_worker = total_tokens.index(min(total_tokens))
+            onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
+            # 2. write the new onfly info to the shm
+            self.balance_meta.set_shared_onfly_info(onfly_info)
+
+        # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
+        self.workers[target_worker].send_pyobj(req)
+
     def event_loop(self):
         while True:
             while True:
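The new MINIMUM_TOKENS policy above routes each request to the worker whose combined load, tokens already being processed plus tokens of requests still in flight to it, is smallest. A self-contained sketch of just that selection step (illustrative; pick_min_token_worker and its arguments are not names from the package):

# Illustrative sketch of the "minimum tokens" worker choice; not sglang's API.
from typing import Dict, List


def pick_min_token_worker(
    local_tokens: List[int], onfly: List[Dict[int, int]]
) -> int:
    """Return the index of the worker with the smallest combined load.

    local_tokens[i]: tokens currently being processed by worker i.
    onfly[i]: {dp_balance_id: num_input_tokens} for requests already dispatched
              to worker i but not yet picked up by its scheduler.
    """
    totals = [lt + sum(d.values()) for lt, d in zip(local_tokens, onfly)]
    return totals.index(min(totals))


# Worker 1 has the smallest local + in-flight total, so the next request goes there.
assert pick_min_token_worker([100, 40, 80], [{1: 30}, {2: 10}, {}]) == 1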
@@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
     setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
+    balance_meta = DPBalanceMeta(server_args.dp_size)

     try:
-        controller = DataParallelController(
+        controller = DataParallelController(
+            server_args, port_args, dp_balance_meta=balance_meta
+        )
         pipe_writer.send(
             {
                 "status": "ready",
@@ -323,3 +370,6 @@ def run_data_parallel_controller_process(
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
+    finally:
+        # we need to destruct mp.Manager() in balance_meta
+        balance_meta.destructor()
sglang/srt/managers/io_struct.py
CHANGED
@@ -523,6 +523,9 @@ class TokenizedGenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None

+    # For dp balance
+    dp_balance_id: int = -1
+

 @dataclass
 class EmbeddingReqInput:
@@ -648,6 +651,8 @@ class TokenizedEmbeddingReqInput:
     token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
+    # For dp balance
+    dp_balance_id: int = -1


 @dataclass
@@ -1097,7 +1102,7 @@ class UnloadLoRAAdapterReqInput:
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str, LoRARef] =
+    loaded_adapters: Optional[Dict[str, LoRARef]] = None


 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
@@ -85,9 +86,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_dp_attention",
     "enable_two_batch_overlap",
     "enable_dp_lm_head",
-    "
+    "moe_a2a_backend",
     "deepep_mode",
-    "enable_ep_moe",
     "enable_flashinfer_cutlass_moe",
     "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
@@ -108,6 +108,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
     "enable_multimodal",
+    "enable_symm_mem",
 ]

 # Put some global args for easy access
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -455,7 +455,9 @@ class PrefillAdder:
         if not self.is_hybrid:
             # Skip this logic for swa. The SWA has different memory management, and
             # this mechanism is underestimating the memory usage.
-            cur_rem_tokens = self.cur_rem_tokens -
+            cur_rem_tokens = self.cur_rem_tokens - self.ceil_paged_tokens(
+                req.extend_input_len
+            )
             tokens_freed = 0
             for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
                 # tokens_left gives a reservative calculation as the last token is not stored
sglang/srt/managers/scheduler.py
CHANGED
@@ -64,6 +64,7 @@ from sglang.srt.hf_transformers_utils import (
 )
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
@@ -125,7 +126,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
 from sglang.srt.managers.session_controller import Session
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
-from sglang.srt.managers.utils import validate_input_length
+from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -137,7 +138,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
-    DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
     configure_gc_logger,
@@ -203,6 +203,7 @@ class Scheduler(
         moe_ep_rank: int,
         pp_rank: int,
         dp_rank: Optional[int],
+        dp_balance_meta: Optional[DPBalanceMeta] = None,
     ):
         # Parse args
         self.server_args = server_args
@@ -522,6 +523,15 @@ class Scheduler(
             ]
         )

+        self.balance_meta = dp_balance_meta
+        if (
+            server_args.enable_dp_attention
+            and server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+            self.recv_dp_balance_id_this_term = []
+
     def init_tokenizer(self):
         server_args = self.server_args

@@ -569,7 +579,23 @@ class Scheduler(
                 page_size=self.page_size,
             )
         else:
-            if
+            if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
+                # lazy import to avoid JIT overhead
+                from sglang.srt.mem_cache.radix_cache_cpp import RadixCacheCpp
+
+                self.tree_cache = RadixCacheCpp(
+                    disable=False,
+                    use_hicache=self.enable_hierarchical_cache,
+                    req_to_token_pool=self.req_to_token_pool,
+                    token_to_kv_pool=self.token_to_kv_pool_allocator,
+                    tp_cache_group=self.tp_cpu_group,
+                    page_size=self.page_size,
+                    hicache_ratio=server_args.hicache_ratio,
+                    hicache_size=server_args.hicache_size,
+                    hicache_write_policy=server_args.hicache_write_policy,
+                    enable_kv_cache_events=self.enable_kv_cache_events,
+                )
+            elif self.enable_hierarchical_cache:
                 self.tree_cache = HiRadixCache(
                     req_to_token_pool=self.req_to_token_pool,
                     token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
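Per the hunk above, the experimental C++ radix tree is chosen ahead of the hierarchical cache whenever the SGLANG_EXPERIMENTAL_CPP_RADIX_TREE environment variable is set to "1". A minimal sketch of that precedence (illustrative; the helper name and the plain RadixCache fallback are assumptions, not shown in the hunk):

import os


def select_tree_cache_impl(enable_hierarchical_cache: bool) -> str:
    # Opt-in C++ tree takes precedence, then the hierarchical (HiRadix) cache,
    # then the default radix cache (assumed fallback).
    if os.environ.get("SGLANG_EXPERIMENTAL_CPP_RADIX_TREE") == "1":
        return "RadixCacheCpp"
    if enable_hierarchical_cache:
        return "HiRadixCache"
    return "RadixCache"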
@@ -588,6 +614,7 @@ class Scheduler(
                         == "fa3"  # hot fix for incompatibility
                         else server_args.hicache_io_backend
                     ),
+                    hicache_mem_layout=server_args.hicache_mem_layout,
                     hicache_storage_backend=server_args.hicache_storage_backend,
                 )
                 self.tp_worker.register_hicache_layer_transfer_counter(
@@ -1032,6 +1059,12 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
         # Create a new request
         if (
             recv_req.session_params is None
@@ -1442,6 +1475,11 @@ class Scheduler(

         # Handle DP attention
         if need_dp_attn_preparation:
+            if (
+                self.server_args.load_balance_method == "minimum_tokens"
+                and self.forward_ct % 40 == 0
+            ):
+                self.handle_dp_balance_data(ret)
             ret = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1743,6 +1781,9 @@ class Scheduler(
         elif batch.forward_mode.is_dummy_first():
             self.set_next_batch_sampling_info_done(batch)

+        self.maybe_send_health_check_signal()
+
+    def maybe_send_health_check_signal(self):
         if self.return_health_check_ct:
             # Return some signal for the health check.
             # This is used to prevent the health check signal being blocked by long context prefill.
@@ -1761,12 +1802,94 @@ class Scheduler(
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
-            enable_deepep_moe=
-
+            enable_deepep_moe=MoeA2ABackend(
+                self.server_args.moe_a2a_backend
+            ).is_deepep(),
+            deepep_mode=DeepEPMode(self.server_args.deepep_mode),
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

+    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
+        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
+            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+            recv_list = self.recv_dp_balance_id_this_term
+            assert len(recv_list) <= 511, (
+                "The number of requests received this round is too large. "
+                "Please increase gather_tensor_size and onfly_info_size."
+            )
+            # The maximum size of the tensor used for gathering data from all workers.
+            gather_tensor_size = 512
+
+            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+            recv_tensor[0] = holding_tokens_list
+            recv_tensor[1] = len(
+                recv_list
+            )  # The first element is the length of the list.
+            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
+                recv_list, dtype=torch.int32
+            )
+
+            if self.tp_rank == 0:
+                gathered_list = [
+                    torch.zeros(gather_tensor_size, dtype=torch.int32)
+                    for _ in range(self.balance_meta.num_workers)
+                ]
+            else:
+                gathered_list = None
+
+            torch.distributed.gather(
+                recv_tensor, gathered_list, group=self.tp_cpu_group
+            )
+
+            gathered_id_list_per_worker = None
+            if self.tp_rank == 0:
+                gathered_id_list_per_worker = []
+                holding_tokens_list = []
+                for tensor in gathered_list:
+                    holding_tokens_list.append(tensor[0].item())
+                    list_length = tensor[1].item()
+                    gathered_id_list_per_worker.append(
+                        tensor[2 : list_length + 2].tolist()
+                    )
+
+            return gathered_id_list_per_worker, holding_tokens_list
+
+        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
+            meta = self.balance_meta
+
+            with meta.mutex:
+                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+                assert len(new_recv_rid_lists) == len(
+                    onfly_list
+                ), "num_worker not equal"
+                # 1.Check if the rid received by each worker this round is present in onfly.
+                # If it is, remove the corresponding onfly item.
+                worker_id = 0
+                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                    for new_recv_rid in new_recv_rids:
+                        assert (
+                            new_recv_rid in on_fly_reqs
+                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                        del on_fly_reqs[new_recv_rid]
+                    worker_id += 1
+                # 2. Atomically write local_tokens and onfly into shm under the mutex
+                meta.set_shared_onfly_info(onfly_list)
+                meta.set_shared_local_tokens(local_tokens)
+
+        holding_tokens = self.get_load()
+
+        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
+            holding_tokens
+        )
+
+        self.recv_dp_balance_id_this_term.clear()
+        if self.tp_rank == 0:  # only first worker write info
+            write_shared_dp_balance_info(
+                new_recv_dp_balance_id_list, holding_token_list
+            )
+
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
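gather_dp_balance_info above packs each worker's report into a fixed-size int32 tensor with the layout | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids | zero padding |, which is what makes a plain torch.distributed.gather of equally sized tensors possible. A standalone sketch of the packing and unpacking (illustrative; pack and unpack are not names from the package, only the 512-slot layout comes from the diff):

# Standalone sketch of the fixed-size gather payload; illustrative helper names.
import torch

GATHER_TENSOR_SIZE = 512


def pack(holding_tokens: int, balance_ids: list) -> torch.Tensor:
    # layout: | holding_tokens | len(balance_ids) | balance_ids ... | zero padding |
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(balance_ids)
    t[2 : 2 + len(balance_ids)] = torch.tensor(balance_ids, dtype=torch.int32)
    return t


def unpack(t: torch.Tensor) -> tuple:
    n = int(t[1].item())
    return int(t[0].item()), t[2 : 2 + n].tolist()


holding, ids = unpack(pack(1234, [7, 8, 9]))
assert holding == 1234 and ids == [7, 8, 9]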
@@ -2343,11 +2466,19 @@ class IdleSleeper:

     def __init__(self, sockets):
         self.poller = zmq.Poller()
+        self.last_empty_time = time.time()
         for s in sockets:
             self.poller.register(s, zmq.POLLIN)

     def maybe_sleep(self):
         self.poller.poll(1000)
+        if (
+            global_config.torch_empty_cache_interval > 0
+            and time.time() - self.last_empty_time
+            > global_config.torch_empty_cache_interval
+        ):
+            self.last_empty_time = time.time()
+            torch.cuda.empty_cache()


 def is_health_check_generate_req(recv_req):
@@ -2367,6 +2498,7 @@ def run_scheduler_process(
     pp_rank: int,
     dp_rank: Optional[int],
     pipe_writer,
+    balance_meta: Optional[DPBalanceMeta] = None,
 ):
     # Generate the prefix
     prefix = ""
@@ -2400,7 +2532,14 @@ def run_scheduler_process(
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(
-            server_args,
+            server_args,
+            port_args,
+            gpu_id,
+            tp_rank,
+            moe_ep_rank,
+            pp_rank,
+            dp_rank,
+            dp_balance_meta=balance_meta,
         )
         pipe_writer.send(
             {
sglang/srt/managers/template_manager.py
CHANGED
@@ -84,26 +84,27 @@ class TemplateManager:
         if chat_template_arg:
             self._load_explicit_chat_template(tokenizer_manager, chat_template_arg)
         else:
-            #
-            hf_template = self._resolve_hf_chat_template(tokenizer_manager)
-            if hf_template:
-                self._jinja_template_content_format = (
-                    detect_jinja_template_content_format(hf_template)
-                )
-                logger.info(
-                    f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
-                )
-                return
-
-            # Fallback to SGLang template guessing
+            # Guess chat template from model path
             self.guess_chat_template_from_model_path(model_path)

-            #
+            # If no pre-defined template was found, fallback to HuggingFace template
             if self._chat_template_name is None:
-
-
-
-
+                # Try HuggingFace template first
+                hf_template = self._resolve_hf_chat_template(tokenizer_manager)
+                if hf_template:
+                    # override the chat template
+                    tokenizer_manager.tokenizer.chat_template = hf_template
+                    self._jinja_template_content_format = (
+                        detect_jinja_template_content_format(hf_template)
+                    )
+                    logger.info(
+                        f"Using default HuggingFace chat template with detected content format: {self._jinja_template_content_format}"
+                    )
+                    return
+
+                # Default to string content format if no template was found
+                self._jinja_template_content_format = "string"
+                logger.info("No chat template found, defaulting to 'string' content format")

     def _load_explicit_chat_template(
         self, tokenizer_manager, chat_template_arg: str
@@ -257,13 +258,15 @@ class TemplateManager:

         Returns the chat template string if found, None otherwise.
         """
-        tokenizer = tokenizer_manager.tokenizer
-
-        # Try to get AutoTokenizer chat template
         try:
-
+            if processor := tokenizer_manager.processor:
+                if hasattr(processor, "chat_template") and processor.chat_template:
+                    return processor.chat_template
+            if tokenizer := tokenizer_manager.tokenizer:
+                if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
+                    return tokenizer.chat_template
         except Exception as e:
-            logger.debug(f"Error getting chat template
+            logger.debug(f"Error getting chat template: {e}")

         logger.debug("No HuggingFace chat template found")
         return None