sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
 import logging
 import math
 import threading
+import time
 from queue import Empty, Full, PriorityQueue, Queue
 from typing import TYPE_CHECKING, List, Optional
 
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
         self._done_flag = False
         self._lock = threading.Lock()
 
+        self.start_time = time.monotonic()
+
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
@@ -236,18 +239,19 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            self.storage_backend_type = storage_backend
            from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
 
            if storage_backend == "file":
                self.storage_backend = HiCacheFile()
                self.get_hash_str = get_hash_str
            elif storage_backend == "nixl":
-                from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+                from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
 
                self.storage_backend = HiCacheNixl()
                self.get_hash_str = get_hash_str
            elif storage_backend == "mooncake":
-                from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+                from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                    MooncakeStore,
                    get_hash_str_mooncake,
                )
@@ -277,6 +281,12 @@ class HiCacheController:
            self.enable_storage = True
            # todo: threshold policy for prefetching
            self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            self.prefetch_capacity_limit = int(
+                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
+            )
+            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
+            self.prefetch_tokens_occupied = 0
+
            # create a new communication group for synchronizing storage operations across TP workers
            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
            if self.tp_world_size > 1:
@@ -524,7 +534,7 @@ class HiCacheController:
        host_indices: torch.Tensor,
        new_input_tokens: List[int],
        last_hash: Optional[str] = None,
-    ) ->
+    ) -> PrefetchOperation:
        """
        Prefetch KV caches from storage backend to host memory.
        """
@@ -573,6 +583,9 @@ class HiCacheController:
        self.storage_backend.batch_get(key_strs, buffer_ptrs, buffer_sizes)
        operation.increment(len(operation.hash_value) * self.page_size)
 
+    def is_mooncake_backend(self):
+        return self.storage_backend_type == "mooncake"
+
    def prefetch_io_aux_func(self):
        """
        Auxiliary function conducting IO operations for prefetching.
@@ -580,13 +593,25 @@ class HiCacheController:
        while not self.stop_event.is_set():
            try:
                operation = self.prefetch_buffer.get(block=True, timeout=1)
-                if
+                if self.is_mooncake_backend():
                    self.mooncake_page_transfer(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_transfer(operation, batch_size=128)
                else:
                    self.generic_page_transfer(operation)
            except Empty:
                continue
 
+    def prefetch_rate_limit_check(self) -> bool:
+        """
+        Rate limit the prefetching operations to avoid overwhelming the storage backend.
+        """
+        # cancel prefetch if too much memory is occupied
+        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
+            return False
+        # todo: more sophisticated rate limiting based on storage backend performance
+        return True
+
    def prefetch_thread_func(self):
        """
        Manage prefetching operations from storage backend to host memory.
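
For orientation, the new rate limit amounts to a small capacity check: in-flight prefetches may lock at most 80% of the host-only KV cache headroom. The following standalone sketch (with made-up pool sizes rather than the real mem_pool_host / mem_pool_device objects) shows the same arithmetic outside the class:

# Illustrative sketch of the capacity gate above; not part of HiCacheController.
def prefetch_capacity_limit(host_pool_size: int, device_pool_size: int) -> int:
    # At most 80% of the host-only headroom may be locked by in-flight prefetches.
    return int(0.8 * (host_pool_size - device_pool_size))

def prefetch_rate_limit_check(tokens_occupied: int, capacity_limit: int) -> bool:
    # Mirrors the method above: cancel prefetch once the occupied-token counter
    # (updated by the main scheduler thread) reaches the limit.
    return tokens_occupied < capacity_limit

limit = prefetch_capacity_limit(host_pool_size=1_000_000, device_pool_size=200_000)  # 640_000
print(prefetch_rate_limit_check(500_000, limit))  # True  -> prefetch admitted
print(prefetch_rate_limit_check(640_000, limit))  # False -> prefetch skipped
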
@@ -600,34 +625,36 @@ class HiCacheController:
            if operation is None:
                continue
 
-            last_hash = operation.last_hash
-            tokens_to_fetch = operation.token_ids
            storage_hit_count = 0
+            if self.prefetch_rate_limit_check():
+                last_hash = operation.last_hash
+                tokens_to_fetch = operation.token_ids
+
+                remaining_tokens = len(tokens_to_fetch)
+                hash_value = []
+                while remaining_tokens >= self.page_size:
+                    last_hash = self.get_hash_str(
+                        tokens_to_fetch[
+                            storage_hit_count : storage_hit_count + self.page_size
+                        ],
+                        last_hash,
+                    )
+
+                    # todo, more unified interface
+                    if not self.is_mooncake_backend():
+                        if not self.storage_backend.exists(last_hash):
+                            break
+                    hash_value.append(last_hash)
+                    storage_hit_count += self.page_size
+                    remaining_tokens -= self.page_size
+
+                if self.is_mooncake_backend():
+                    # deferring to batch exists for mooncake store
+                    exist_result = self.storage_backend.exists(hash_value)
+                    storage_hit_count = (
+                        sum(1 for v in exist_result.values() if v != 0)
+                        * self.page_size
+                    )
 
            if self.tp_world_size > 1:
                storage_hit_count_tensor = torch.tensor(
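
With the Mooncake backend, the per-page exists probes are deferred to a single batched lookup, and the hit count becomes the number of pages reported present. A tiny self-contained illustration of that counting step (fake hash keys and a fake lookup result, not the real MooncakeStore API):

# Toy illustration of the batched-exists accounting used for the mooncake backend.
page_size = 64
hash_value = ["h0", "h1", "h2", "h3"]                 # one hash per page walked above
exist_result = {"h0": 1, "h1": 1, "h2": 0, "h3": 1}   # pretend reply from exists(hash_value)

# A page counts as a hit when its entry is non-zero, exactly as in the hunk above.
storage_hit_count = sum(1 for v in exist_result.values() if v != 0) * page_size
print(storage_hit_count)  # 192 tokens hit (3 pages * 64 tokens per page)
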
@@ -744,8 +771,10 @@ class HiCacheController:
                remaining_tokens -= self.page_size
            operation.hash_value = hash_value
 
-            if
+            if self.is_mooncake_backend():
                self.mooncake_page_backup(operation)
+            elif self.storage_backend_type == "hf3fs":
+                self.generic_page_backup(operation, batch_size=128)
            else:
                self.generic_page_backup(operation)
 

sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -16,9 +16,13 @@
 import logging
 import multiprocessing as mp
 import signal
+import struct
+import sys
 import threading
 import time
 from enum import Enum, auto
+from multiprocessing import shared_memory
+from typing import Dict, List
 
 import psutil
 import setproctitle
@@ -32,6 +36,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.managers.scheduler import run_scheduler_process
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
@@ -45,6 +50,7 @@ class LoadBalanceMethod(Enum):
 
    ROUND_ROBIN = auto()
    SHORTEST_QUEUE = auto()
+    MINIMUM_TOKENS = auto()
 
    @classmethod
    def from_str(cls, method: str):
@@ -58,7 +64,16 @@ class LoadBalanceMethod(Enum):
 class DataParallelController:
    """A controller that dispatches requests to multiple data parallel workers."""
 
-    def __init__(
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        dp_balance_meta: DPBalanceMeta,
+    ) -> None:
+        # for dp balance
+        self.global_balance_id = 0
+        self.balance_meta = dp_balance_meta
+
        # Parse args
        self.max_total_num_tokens = None
        self.server_args = server_args
@@ -79,6 +94,7 @@ class DataParallelController:
        dispatch_lookup = {
            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
+            LoadBalanceMethod.MINIMUM_TOKENS: self.minimum_tokens_scheduler,
        }
        self.dispatching = dispatch_lookup[self.load_balance_method]
 
@@ -234,6 +250,7 @@ class DataParallelController:
                    pp_rank,
                    dp_rank,
                    writer,
+                    self.balance_meta,
                ),
            )
            with memory_saver_adapter.configure_subprocess():
@@ -269,6 +286,33 @@ class DataParallelController:
    def shortest_queue_scheduler(self, input_requests):
        raise NotImplementedError()
 
+    def minimum_tokens_scheduler(self, req):
+        # This variable corresponds to the balance_id in TokenizedGenerateReqInput.
+        # We use it to control the number of onfly tokens (requests dispatched to workers but not yet received).
+        def get_next_global_balance_id() -> int:
+            INT32_MAX = 2147483647
+            current_id = self.global_balance_id
+            self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
+            return current_id
+
+        req.dp_balance_id = get_next_global_balance_id()
+        with self.balance_meta.mutex:
+            # 1. local_tokens represents the tokens currently inferring on the worker,
+            # while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
+            onfly_info = self.balance_meta.get_shared_onfly()
+            local_tokens = self.balance_meta.get_shared_local_tokens()
+            total_tokens = [
+                local_token + sum(onfly_dict.values())
+                for local_token, onfly_dict in zip(local_tokens, onfly_info)
+            ]
+            target_worker = total_tokens.index(min(total_tokens))
+            onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
+            # 2. write the new onfly info to the shm
+            self.balance_meta.set_shared_onfly_info(onfly_info)
+
+        # logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
+        self.workers[target_worker].send_pyobj(req)
+
    def event_loop(self):
        while True:
            while True:
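
The dispatch rule itself is an argmin over per-worker load, where load is the tokens a worker is currently running plus the tokens of requests still in flight to it. A standalone walk-through with toy numbers (plain lists and dicts in place of the shared-memory state kept in DPBalanceMeta):

# Toy walk-through of the minimum-tokens dispatch decision shown above.
local_tokens = [1200, 800, 950]                    # tokens currently running per DP worker
onfly_info = [{7: 300}, {8: 100, 9: 50}, {}]       # dp_balance_id -> input length, not yet received

total_tokens = [
    local_token + sum(onfly_dict.values())
    for local_token, onfly_dict in zip(local_tokens, onfly_info)
]
# total_tokens == [1500, 950, 950]; index() picks the first minimum, worker 1.
target_worker = total_tokens.index(min(total_tokens))

# Before dispatch, the new request (balance_id 10, 400 input tokens) is recorded
# as on-the-fly for that worker, so later decisions see it immediately.
onfly_info[target_worker][10] = 400
print(target_worker, onfly_info[target_worker])  # 1 {8: 100, 9: 50, 10: 400}
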
@@ -302,9 +346,12 @@ def run_data_parallel_controller_process(
    setproctitle.setproctitle("sglang::data_parallel_controller")
    configure_logger(server_args)
    parent_process = psutil.Process().parent()
+    balance_meta = DPBalanceMeta(server_args.dp_size)
 
    try:
-        controller = DataParallelController(
+        controller = DataParallelController(
+            server_args, port_args, dp_balance_meta=balance_meta
+        )
        pipe_writer.send(
            {
                "status": "ready",
@@ -323,3 +370,6 @@ def run_data_parallel_controller_process(
        traceback = get_exception_traceback()
        logger.error(f"DataParallelController hit an exception: {traceback}")
        parent_process.send_signal(signal.SIGQUIT)
+    finally:
+        # we need to destruct mp.Manager() in balance_meta
+        balance_meta.destructor()

sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -216,7 +216,7 @@ class DetokenizerManager:
                rids=recv_obj.rids,
                finished_reasons=recv_obj.finished_reasons,
                output_strs=output_strs,
-                output_ids=
+                output_ids=recv_obj.decode_ids,
                prompt_tokens=recv_obj.prompt_tokens,
                completion_tokens=recv_obj.completion_tokens,
                cached_tokens=recv_obj.cached_tokens,

sglang/srt/managers/io_struct.py CHANGED
@@ -26,6 +26,7 @@ from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
 from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.utils import ImageData
 
 # Handle serialization of Image for pydantic
 if TYPE_CHECKING:
@@ -45,7 +46,7 @@ class SessionParams:
 
 # Type definitions for multimodal input data
 # Individual data item types for each modality
-ImageDataInputItem = Union[Image, str, Dict]
+ImageDataInputItem = Union[Image, str, ImageData, Dict]
 AudioDataInputItem = Union[str, Dict]
 VideoDataInputItem = Union[str, Dict]
 # Union type for any multimodal data item
@@ -101,8 +102,10 @@ class GenerateReqInput:
 
    # The modalities of the image data [image, multi-images, video]
    modalities: Optional[List[str]] = None
-    # The path to the LoRA
+    # The path to the LoRA adaptors
    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # The uid of LoRA adaptors, should be initialized by tokenizer manager
+    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
    # Session info for continual prompting
    session_params: Optional[Union[List[Dict], Dict]] = None
@@ -123,6 +126,9 @@ class GenerateReqInput:
    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None
 
+    # For background responses (OpenAI responses API)
+    background: bool = False
+
    def contains_mm_input(self) -> bool:
        return (
            has_valid_data(self.image_data)
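
The background flag pairs with the new OpenAI Responses entrypoint listed above (sglang/srt/entrypoints/openai/serving_responses.py). As a hedged construction example, assuming every other GenerateReqInput field keeps its default:

from sglang.srt.managers.io_struct import GenerateReqInput

# Assumption: only the prompt text and the new background flag are set; all
# other GenerateReqInput fields are left at their defaults for illustration.
req = GenerateReqInput(
    text="Summarize the latest release notes.",
    background=True,  # run as a background response (OpenAI responses API)
)
print(req.background)  # True
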
@@ -500,7 +506,7 @@ class TokenizedGenerateReqInput:
    stream: bool
 
    # LoRA related
-
+    lora_id: Optional[str] = None  # None means just use the base model
    # The input embeds
    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
@@ -523,6 +529,9 @@ class TokenizedGenerateReqInput:
    # For data parallel rank routing
    data_parallel_rank: Optional[int] = None
 
+    # For dp balance
+    dp_balance_id: int = -1
+
 
 @dataclass
 class EmbeddingReqInput:
@@ -554,6 +563,9 @@ class EmbeddingReqInput:
    # For cross-encoder requests
    is_cross_encoder_request: bool = False
 
+    # For background responses (OpenAI responses API)
+    background: bool = False
+
    def normalize_batch_and_arguments(self):
        # at least one of text, input_ids, or image should be provided
        if self.text is None and self.input_ids is None and self.image_data is None:
@@ -648,6 +660,8 @@ class TokenizedEmbeddingReqInput:
    token_type_ids: List[int]
    # Dummy sampling params for compatibility
    sampling_params: SamplingParams
+    # For dp balance
+    dp_balance_id: int = -1
 
 
 @dataclass
@@ -1068,6 +1082,8 @@ class LoadLoRAAdapterReqInput:
    lora_name: str
    # The path of loading.
    lora_path: str
+    # Whether to pin the LoRA adapter in memory.
+    pinned: bool = False
    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
    lora_id: Optional[str] = None
 
@@ -1076,6 +1092,7 @@ class LoadLoRAAdapterReqInput:
            lora_id=self.lora_id,
            lora_name=self.lora_name,
            lora_path=self.lora_path,
+            pinned=self.pinned,
        )
 
 
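
Taken together, the two hunks above add a pinned flag to the load request and forward it into the reference object the request builds. A hedged usage sketch; the to_ref() method name and the pinned attribute on the resulting LoRARef are assumptions inferred from the constructor call in the hunk:

from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput

# Assumption: lora_id stays unset here because the TokenizerManager generates it.
req = LoadLoRAAdapterReqInput(
    lora_name="my-adapter",
    lora_path="/models/loras/my-adapter",
    pinned=True,  # ask that the adapter stay pinned in memory
)
ref = req.to_ref()   # assumed helper that builds the LoRARef shown above
print(ref.pinned)    # True
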
@@ -1097,7 +1114,7 @@ class UnloadLoRAAdapterReqInput:
 class LoRAUpdateResult:
    success: bool
    error_message: Optional[str] = None
-    loaded_adapters: Dict[str, LoRARef] =
+    loaded_adapters: Optional[Dict[str, LoRARef]] = None
 
 
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult

sglang/srt/managers/mm_utils.py CHANGED
@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
        embedding_per_req = data_embedding_func(embedding_items_per_req)
        if not embedding_cache.put(embedding_items_hash, embedding_per_req):
            print_warning_once(
-                "Multimodal embedding cache is full.
-                "
+                "Multimodal embedding cache is full. This typically occurs when a single "
+                "embedding exceeds the cache size limit. Consider increasing the "
+                "`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
+                "embedding size."
            )
 
-        embedding_per_req_chunk, _,
+        embedding_per_req_chunk, _, _ = get_embedding_chunk(
            embedding=embedding_per_req,
            extend_prefix_len=prefix_length[i],
            extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
            items_offset=items_offset,
        )
-        # remove this item from cache if chunk reaches to the end
-        embedding_per_req_length = (
-            embedding_per_req.shape[0]
-            if embedding_per_req.dim() == 2
-            else embedding_per_req.shape[0] * embedding_per_req.shape[1]
-        )
-        if end_index == embedding_per_req_length:
-            embedding_cache.free(embedding_items_hash)
        embedding_list.append(embedding_per_req_chunk)
    if len(embedding_list) == 0:
        return None

sglang/srt/managers/schedule_batch.py CHANGED
@@ -84,10 +84,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
    "disable_radix_cache",
    "enable_dp_attention",
    "enable_two_batch_overlap",
+    "tbo_token_distribution_threshold",
    "enable_dp_lm_head",
-    "
+    "moe_a2a_backend",
    "deepep_mode",
-    "enable_ep_moe",
    "enable_flashinfer_cutlass_moe",
    "enable_flashinfer_trtllm_moe",
    "enable_flashinfer_allreduce_fusion",
@@ -107,7 +107,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
    "num_reserved_decode_tokens",
    "weight_loader_disable_mmap",
    "enable_triton_kernel_moe",
+    "enable_flashinfer_mxfp4_moe",
    "enable_multimodal",
+    "enable_symm_mem",
+    "quantization",
 ]
 
 # Put some global args for easy access
@@ -422,7 +425,7 @@ class Req:
        token_ids_logprob: List[int] = None,
        stream: bool = False,
        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
-
+        lora_id: Optional[str] = None,
        input_embeds: Optional[List[List[float]]] = None,
        token_type_ids: List[int] = None,
        session_id: Optional[str] = None,
@@ -466,7 +469,7 @@ class Req:
        self.sampling_params = sampling_params
        self.custom_logit_processor = custom_logit_processor
        self.return_hidden_states = return_hidden_states
-        self.
+        self.lora_id = lora_id
 
        # Memory pool info
        self.req_pool_idx: Optional[int] = None
@@ -844,6 +847,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
    # The sum of all sequence lengths
    seq_lens_sum: int = None
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: torch.Tensor = None  # shape: [b], int32
 
    # For DP attention
    global_num_tokens: Optional[List[int]] = None
@@ -916,8 +921,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
        is_hybrid = False
        if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
-            assert
-            tree_cache
+            assert (
+                tree_cache is None
+                or isinstance(tree_cache, SWARadixCache)
+                or isinstance(tree_cache, SWAChunkCache)
            ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
            is_hybrid = True
 
@@ -1127,6 +1134,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = [len(r.fill_ids) for r in reqs]
+        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
        prefix_lens = [len(r.prefix_indices) for r in reqs]
        extend_lens = [r.extend_input_len for r in reqs]
 
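
orig_seq_lens keeps the full original prompt length visible (the "Qwen-1M" long-context case) even while only part of it has been prefilled, whereas seq_lens tracks just the tokens filled so far. A tiny worked example of the expression above:

# Toy request state midway through chunked prefill of a very long prompt.
origin_input_ids = list(range(1_000_000))   # the full original prompt
fill_ids = origin_input_ids[:8_192]         # only the first chunk is filled so far

seq_len = len(fill_ids)                                   # 8192
orig_seq_len = max(len(fill_ids), len(origin_input_ids))  # 1000000
print(seq_len, orig_seq_len)
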
@@ -1143,6 +1151,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
            self.device, non_blocking=True
        )
+        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
        prefix_lens_tensor = torch.tensor(
            prefix_lens, dtype=torch.int64, device=self.device
        )
@@ -1256,6 +1267,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        self.input_ids = input_ids_tensor
        self.req_pool_indices = req_pool_indices_tensor
        self.seq_lens = seq_lens_tensor
+        self.orig_seq_lens = orig_seq_lens_tensor
        self.out_cache_loc = out_cache_loc
        self.input_embeds = (
            torch.tensor(input_embeds).to(self.device, non_blocking=True)
@@ -1503,6 +1515,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        self.forward_mode = ForwardMode.IDLE
        self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
        self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
        self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
@@ -1557,9 +1570,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        if self.enable_overlap:
            # Do not use in-place operations in the overlap mode
            self.seq_lens = self.seq_lens + 1
+            self.orig_seq_lens = self.orig_seq_lens + 1
        else:
            # A faster in-place version
            self.seq_lens.add_(1)
+            self.orig_seq_lens.add_(1)
        self.seq_lens_sum += bs
 
        # free memory
@@ -1623,6 +1638,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
        self.req_pool_indices = self.req_pool_indices[keep_indices_device]
        self.seq_lens = self.seq_lens[keep_indices_device]
+        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
        self.out_cache_loc = None
        self.seq_lens_sum = self.seq_lens.sum().item()
        self.output_ids = self.output_ids[keep_indices_device]
@@ -1655,6 +1671,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            [self.req_pool_indices, other.req_pool_indices]
        )
        self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
        self.out_cache_loc = None
        self.seq_lens_sum += other.seq_lens_sum
        if self.output_ids is not None:
@@ -1704,6 +1721,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            or attention_backend_str == "flashmla"
            or attention_backend_str == "cutlass_mla"
            or attention_backend_str == "ascend"
+            or attention_backend_str == "trtllm_mha"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = (
@@ -1728,6 +1746,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            input_ids=self.input_ids,
            req_pool_indices=self.req_pool_indices,
            seq_lens=self.seq_lens,
+            orig_seq_lens=self.orig_seq_lens,
            out_cache_loc=self.out_cache_loc,
            seq_lens_cpu=seq_lens_cpu,
            seq_lens_sum=self.seq_lens_sum,
@@ -1749,7 +1768,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            encoder_lens=self.encoder_lens,
            encoder_lens_cpu=self.encoder_lens_cpu,
            encoder_out_cache_loc=self.encoder_out_cache_loc,
-
+            lora_ids=[req.lora_id for req in self.reqs],
            sampling_info=self.sampling_info,
            input_embeds=self.input_embeds,
            token_type_ids=self.token_type_ids,
@@ -1890,11 +1909,14 @@ class ModelWorkerBatch:
    encoder_out_cache_loc: Optional[torch.Tensor]
 
    # For LoRA
-
+    lora_ids: Optional[List[str]]
 
    # Sampling info
    sampling_info: SamplingBatchInfo
 
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: Optional[torch.Tensor] = None
+
    # The input Embeds
    input_embeds: Optional[torch.Tensor] = None
 

sglang/srt/managers/schedule_policy.py CHANGED
@@ -455,7 +455,9 @@ class PrefillAdder:
        if not self.is_hybrid:
            # Skip this logic for swa. The SWA has different memory management, and
            # this mechanism is underestimating the memory usage.
-            cur_rem_tokens = self.cur_rem_tokens -
+            cur_rem_tokens = self.cur_rem_tokens - self.ceil_paged_tokens(
+                req.extend_input_len
+            )
            tokens_freed = 0
            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
                # tokens_left gives a reservative calculation as the last token is not stored
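
ceil_paged_tokens is not shown in this diff; judging from its use here it presumably rounds a token count up to a whole number of KV cache pages before it is subtracted from the remaining budget. A hedged sketch of that assumed behavior:

# Assumption: ceil_paged_tokens rounds a token count up to a multiple of the page size.
def ceil_paged_tokens(num_tokens: int, page_size: int) -> int:
    return ((num_tokens + page_size - 1) // page_size) * page_size

# With 64-token pages, a 100-token extend reserves two full pages (128 tokens).
print(ceil_paged_tokens(100, page_size=64))  # 128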
|