sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/model_config.py +1 -1
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +49 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +2 -8
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +27 -4
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -4
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +10 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/topk.py +5 -13
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/modelopt_quant.py +8 -4
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +53 -6
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/schedule_batch.py +13 -3
- sglang/srt/managers/scheduler.py +13 -25
- sglang/srt/managers/tokenizer_manager.py +28 -25
- sglang/srt/managers/tp_worker.py +2 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +30 -16
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +41 -23
- sglang/srt/models/deepseek_v2.py +1 -2
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +0 -4
- sglang/srt/models/qwen3_moe.py +1 -6
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +76 -55
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +17 -68
- sglang/test/test_activation.py +50 -1
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +5 -5
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +75 -72
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -219,6 +219,7 @@ class HiCacheController:
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
+        tp_group: torch.distributed.ProcessGroup,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
@@ -244,11 +245,17 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
+
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.enable_storage = True
                 # todo: threshold policy for prefetching
-                self.prefetch_threshold = prefetch_threshold
+                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
@@ -358,6 +365,7 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
+        torch.cuda.current_stream().synchronize()
         self.write_queue.put(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
@@ -567,13 +575,32 @@ class HiCacheController:
                     else:
                         break

+                if self.tp_world_size > 1:
+                    storage_hit_count_tensor = torch.tensor(
+                        storage_hit_count, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        storage_hit_count_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    storage_hit_count = storage_hit_count_tensor.item()
+
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
+                    logger.debug(
+                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
+                    )
                 else:
-                    operation.hash_value = hash_value
+                    operation.hash_value = hash_value[
+                        : (storage_hit_count // self.page_size)
+                    ]
+                    # free the pre-allocated memory for pages that are not hit
+                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
-                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                     )
                     self.prefetch_buffer.put(operation)

@@ -610,17 +637,37 @@ class HiCacheController:
                     last_hash = get_hash_str(
                         tokens_to_backup[i : i + self.page_size], last_hash
                     )
-
-                    self.storage_backend.set(
+                    success = self.storage_backend.set(
                         last_hash,
                         self.mem_pool_host.get_flat_data_page(
                             operation.host_indices[i]
                         ),
                     )
+                    if not success:
+                        logger.warning(f"Failed to write page {last_hash} to storage.")
+                        break
                    operation.completed_tokens += self.page_size
                    operation.hash_value.append(last_hash)

-
+                min_completed_tokens = operation.completed_tokens
+                if self.tp_world_size > 1:
+                    completed_tokens_tensor = torch.tensor(
+                        min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    min_completed_tokens = completed_tokens_tensor.item()
+
+                self.ack_backup_queue.put(
+                    (
+                        operation.id,
+                        operation.hash_value[: min_completed_tokens // self.page_size],
+                        min_completed_tokens,
+                    )
+                )

            except Empty:
                continue
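Note on the pattern above: each TP worker may observe a different number of storage hits (or complete a different number of backup pages), so the workers agree on the minimum via an all-reduce over a dedicated CPU (gloo) group before acting on the value. A minimal stand-alone sketch of that pattern, assuming torch.distributed is already initialized elsewhere; the helper name is illustrative, not sglang API:

    import torch
    import torch.distributed as dist

    def agree_on_min(local_count: int, group: dist.ProcessGroup) -> int:
        # gloo reduces CPU tensors, so a scalar count needs no GPU round-trip
        count = torch.tensor(local_count, dtype=torch.int)
        dist.all_reduce(count, op=dist.ReduceOp.MIN, group=group)
        return int(count.item())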
sglang/srt/managers/io_struct.py
CHANGED
@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
     lora_name: str
     # The path of loading.
     lora_path: str
+    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
+    lora_id: Optional[str] = None
+
+    def to_ref(self) -> LoRARef:
+        return LoRARef(
+            lora_id=self.lora_id,
+            lora_name=self.lora_name,
+            lora_path=self.lora_path,
+        )


 @dataclass
 class UnloadLoRAAdapterReqInput:
     # The name of lora module to unload.
     lora_name: str
+    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
+    lora_id: Optional[str] = None
+
+    def to_ref(self) -> LoRARef:
+        return LoRARef(
+            lora_id=self.lora_id,
+            lora_name=self.lora_name,
+        )


 @dataclass
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str,
+    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)


 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
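For reference, the `to_ref()` helpers above convert request payloads into `LoRARef` objects defined in the new sglang/srt/lora/lora_registry.py (+124 lines, not shown in this diff). A rough sketch of the shape those fields imply; the id auto-generation and exact field set are assumptions, not the released implementation:

    import uuid
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class LoRARef:
        # user-facing adapter name and its load path
        lora_name: str
        lora_path: Optional[str] = None
        lora_id: Optional[str] = None

        def __post_init__(self):
            # assumed: mint a unique id when none is supplied, so that
            # `new_adapter.lora_id` is usable right after construction
            if self.lora_id is None:
                self.lora_id = uuid.uuid4().hex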
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -45,7 +45,6 @@ import triton
 import triton.language as tl

 from sglang.global_config import global_config
-from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
 from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
@@ -68,6 +67,7 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import flatten_nested_list, support_triton

 if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_multimodal",
 ]

 # Put some global args for easy access
@@ -430,6 +431,7 @@ class Req:
         bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
         data_parallel_rank: Optional[int] = None,
+        vocab_size: Optional[int] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -479,6 +481,7 @@ class Req:
         self.to_abort_message: str = None
         self.stream = stream
         self.eos_token_ids = eos_token_ids
+        self.vocab_size = vocab_size

         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -712,6 +715,14 @@ class Req:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
             return

+        if last_token_id > self.vocab_size or last_token_id < 0:
+            if self.sampling_params.stop_token_ids:
+                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
+            if self.eos_token_ids:
+                self.output_ids[-1] = next(iter(self.eos_token_ids))
+            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+            return
+
         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
             tail_str = self.tokenizer.decode(
@@ -1879,7 +1890,7 @@ class ModelWorkerBatch:
     sampling_info: SamplingBatchInfo

     # The input Embeds
-    input_embeds: Optional[torch.
+    input_embeds: Optional[torch.Tensor] = None

     # For corss-encoder model
     token_type_ids: Optional[torch.Tensor] = None
@@ -1889,7 +1900,6 @@ class ModelWorkerBatch:
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
-    spec_num_draft_tokens: Optional[int] = None
     hicache_consumer_index: int = 0

     # Overlap event
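The new out-of-range check in Req.check_finished guards against corrupted (NaN) logits producing a token id outside the vocabulary: the sampled id is overwritten with a known stop/EOS token and the request is finished with a "NaN happened" reason. A stand-alone sketch of the same guard, with plain arguments in place of the Req fields:

    from typing import List, Optional, Set

    def sanitize_last_token(
        output_ids: List[int],
        vocab_size: int,
        stop_token_ids: Optional[Set[int]] = None,
        eos_token_ids: Optional[Set[int]] = None,
    ) -> bool:
        """Return True when the last id is out of range and the request should finish."""
        last_token_id = output_ids[-1]
        if last_token_id > vocab_size or last_token_id < 0:
            # overwrite the bogus id with any known stop/EOS token so that
            # downstream detokenization does not crash
            if stop_token_ids:
                output_ids[-1] = next(iter(stop_token_ids))
            if eos_token_ids:
                output_ids[-1] = next(iter(eos_token_ids))
            return True
        return False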
sglang/srt/managers/scheduler.py
CHANGED
@@ -247,7 +247,7 @@ class Scheduler(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
-        self.
+        self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
@@ -653,6 +653,9 @@ class Scheduler(
             )
         )

+        embedding_cache_size = int(os.environ.get("SGLANG_VLM_CACHE_SIZE_MB", "100"))
+        init_embedding_cache(embedding_cache_size * 1024 * 1024)
+
     def init_profier(self):
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
@@ -1126,6 +1129,7 @@
             bootstrap_port=recv_req.bootstrap_port,
             bootstrap_room=recv_req.bootstrap_room,
             data_parallel_rank=recv_req.data_parallel_rank,
+            vocab_size=self.model_config.vocab_size,
         )
         req.tokenizer = self.tokenizer

@@ -1392,8 +1396,10 @@
         logger.info(f)

         if self.enable_metrics:
-
-
+            total_tokens = adder.log_input_tokens + adder.log_hit_tokens
+
+            cache_hit_rate = (
+                adder.log_hit_tokens / total_tokens if total_tokens > 0 else 0.0
             )
             self.stats.num_running_reqs = running_bs
             self.stats.num_used_tokens = num_used
@@ -1706,13 +1712,13 @@
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)

-        if self.
+        if self.enable_lora:
             lora_set = set([req.lora_path for req in self.running_batch.reqs])

         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
-                self.
+                self.enable_lora
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])
@@ -2466,12 +2472,6 @@
         """In-place loading a new lora adapter from disk or huggingface."""

         result = self.tp_worker.load_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after loading lora adapter."
-        else:
-            logger.error(result.error_message)
         return result

     def unload_lora_adapter(
@@ -2480,14 +2480,6 @@
         """Unload the lora adapter."""

         result = self.tp_worker.unload_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert (
-                flush_cache_success
-            ), "Cache flush failed after unloading LoRA weights"
-        else:
-            logger.error(result.error_message)
         return result

     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
@@ -2909,9 +2901,9 @@ def run_scheduler_process(
         prefix += f" PP{pp_rank}"

     # Config the process
-    kill_itself_when_parent_died()
     setproctitle.setproctitle(f"sglang::scheduler{prefix.replace(' ', '_')}")
     faulthandler.enable()
+    kill_itself_when_parent_died()
     parent_process = psutil.Process().parent()

     # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
@@ -2926,10 +2918,6 @@ def run_scheduler_process(
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)

-    embedding_cache_size = 100
-    if "SGLANG_VLM_CACHE_SIZE_MB" in os.environ:
-        embedding_cache_size = int(os.environ["SGLANG_VLM_CACHE_SIZE_MB"])
-    init_embedding_cache(embedding_cache_size * 1024 * 1024)
     # Create a scheduler and run the event loop
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
@@ -2940,8 +2928,8 @@ def run_scheduler_process(
                 "max_req_input_len": scheduler.max_req_input_len,
             }
         )
-        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode

+        disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
         if disaggregation_mode == DisaggregationMode.NULL:
             if server_args.pp_size > 1:
                 scheduler.event_loop_pp()
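The metrics change above replaces the previous hit-rate computation with a division guarded against empty prefill batches; expressed on its own:

    def cache_hit_rate(log_hit_tokens: int, log_input_tokens: int) -> float:
        # the total can be zero when the prefill adder saw no tokens this round
        total_tokens = log_input_tokens + log_hit_tokens
        return log_hit_tokens / total_tokens if total_tokens > 0 else 0.0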
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
     get_tokenizer_from_processor,
 )
+from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -242,11 +243,11 @@ class TokenizerManager:
             revision=server_args.revision,
         )

-        # Initialize
-        #
-
-
-        )
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})

         # Store states
         self.no_create_loop = False
@@ -523,6 +524,10 @@ class TokenizerManager:
         else:
             mm_inputs = None

+        if self.server_args.enable_lora and obj.lora_path:
+            # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
+            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -574,8 +579,6 @@ class TokenizerManager:
                 "The server is not configured to enable custom logit processor. "
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )
-        if self.server_args.enable_lora and obj.lora_path:
-            self._validate_lora_adapters(obj)

     def _validate_input_ids_in_vocab(
         self, input_ids: List[int], vocab_size: int
@@ -689,21 +692,6 @@ class TokenizerManager:
                 "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
             )

-    def _validate_lora_adapters(self, obj: GenerateReqInput):
-        """Validate that the requested LoRA adapters are loaded."""
-        requested_adapters = (
-            set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
-        )
-        loaded_adapters = (
-            self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
-        )
-        unloaded_adapters = requested_adapters - loaded_adapters
-        if unloaded_adapters:
-            raise ValueError(
-                f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
-                f"Loaded adapters: {loaded_adapters}."
-            )
-
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1054,8 +1042,18 @@ class TokenizerManager:
         )

         async with self.model_update_lock.writer_lock:
+            # Generate new uniquely identifiable LoRARef object.
+            new_adapter = LoRARef(
+                lora_name=obj.lora_name,
+                lora_path=obj.lora_path,
+            )
+
+            # Register the new adapter in the registry.
+            obj.lora_id = new_adapter.lora_id
             result = (await self.update_lora_adapter_communicator(obj))[0]
-
+            if result.success:
+                await self.lora_registry.register(new_adapter)
+
         return result

     async def unload_lora_adapter(
@@ -1069,6 +1067,10 @@ class TokenizerManager:
                 "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
             )

+        assert (
+            obj.lora_name is not None
+        ), "lora_name must be provided to unload LoRA adapter"
+
         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
         assert (
@@ -1080,8 +1082,9 @@ class TokenizerManager:
         )

         async with self.model_update_lock.writer_lock:
+            obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
             result = (await self.update_lora_adapter_communicator(obj))[0]
-
+
         return result

     async def get_weights_by_name(
@@ -1309,7 +1312,7 @@ class TokenizerManager:
             filename = os.path.join(
                 self.crash_dump_folder,
                 os.getenv("HOSTNAME", None),
-                f
+                f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
             )

             os.makedirs(os.path.dirname(filename), exist_ok=True)
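Taken together, the TokenizerManager changes route every dynamic adapter operation through the new LoRARegistry: an id is minted before the workers are notified, the registry is only updated after a successful load, and unloading resolves the user-facing name to its id before dispatch. A condensed sketch of that flow, with `registry` and `communicator` as simplified stand-ins for the real objects:

    async def load_adapter(obj, registry, communicator):
        # mint the id first so every worker sees the same identifier
        new_adapter = LoRARef(lora_name=obj.lora_name, lora_path=obj.lora_path)
        obj.lora_id = new_adapter.lora_id
        result = (await communicator(obj))[0]
        if result.success:
            # the registry only ever contains adapters that loaded successfully
            await registry.register(new_adapter)
        return result

    async def unload_adapter(obj, registry, communicator):
        assert obj.lora_name is not None, "lora_name must be provided to unload LoRA adapter"
        # resolve the user-facing name to the internal id before dispatching
        obj.lora_id = await registry.unregister(obj.lora_name)
        return (await communicator(obj))[0]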
sglang/srt/managers/tp_worker.py
CHANGED
@@ -293,11 +293,9 @@ class TpModelWorker:
         return parameter

     def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-        result = self.model_runner.load_lora_adapter(
-            recv_req.lora_name, recv_req.lora_path
-        )
+        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
         return result

     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-        result = self.model_runner.unload_lora_adapter(recv_req.
+        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
sglang/srt/mem_cache/allocator.py
CHANGED
@@ -51,6 +51,7 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         self._kvcache = kvcache

         self.free_pages = None
+        self.release_pages = None
         self.is_not_in_free_group = True
         self.free_group = []

@@ -58,16 +59,16 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         return ""

     def available_size(self):
-        return len(self.free_pages) * self.page_size
+        return (len(self.free_pages) + len(self.release_pages)) * self.page_size

     def get_kvcache(self):
         return self._kvcache

-    def restore_state(self,
-        self.free_pages =
+    def restore_state(self, state):
+        self.free_pages, self.release_pages = state

     def backup_state(self):
-        return self.free_pages
+        return (self.free_pages, self.release_pages)

     def free_group_begin(self):
         self.is_not_in_free_group = False
@@ -78,6 +79,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
         if self.free_group:
             self.free(torch.cat(self.free_group))

+    def merge_and_sort_free(self):
+        if len(self.release_pages) > 0:
+            self.free_pages = torch.cat((self.free_pages, self.release_pages))
+            self.free_pages, _ = torch.sort(self.free_pages)
+            self.release_pages = torch.empty(
+                (0,), dtype=self.release_pages.dtype, device=self.device
+            )
+
     def get_cpu_copy(self, *args, **kwargs):
         # FIXME: reuse the get_cpu_copy after paged allocator is implemented
         raise NotImplementedError()
@@ -119,12 +128,15 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         self.is_not_in_free_group = True
         self.free_group = []
+        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

     def available_size(self):
         # To avoid minor "len(free_pages) * 1" overhead
-        return len(self.free_pages)
+        return len(self.free_pages) + len(self.release_pages)

     def alloc(self, need_size: int):
+        if need_size > len(self.free_pages):
+            self.merge_and_sort_free()
         if need_size > len(self.free_pages):
             return None

@@ -137,7 +149,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             return

         if self.is_not_in_free_group:
-            self.
+            self.release_pages = torch.cat((self.release_pages, free_index))
         else:
             self.free_group.append(free_index)

@@ -421,6 +433,8 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         ), "The allocation size should be page-aligned"

         num_pages = need_size // self.page_size
+        if num_pages > len(self.free_pages):
+            self.merge_and_sort_free()
         if num_pages > len(self.free_pages):
             return None

@@ -446,6 +460,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             (last_loc + 1) % self.page_size == prefix_lens % self.page_size
         )

+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int64, device=self.device
@@ -483,6 +508,17 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
             (last_loc + 2) % self.page_size == seq_lens % self.page_size
         )

+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (seq_lens - 1 + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
         alloc_decode_kernel[(bs,)](
@@ -511,7 +547,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):

         if self.is_not_in_free_group:
             free_page_indices = torch.unique(free_index // self.page_size)
-            self.
+            self.release_pages = torch.cat((free_page_indices, self.release_pages))
         else:
             self.free_group.append(free_index)

@@ -525,6 +561,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
         )
         self.is_not_in_free_group = True
         self.free_group = []
+        self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

     def get_cpu_copy(self, indices):
         return self._kvcache.get_cpu_copy(indices)
@@ -633,6 +670,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
             (last_loc + 1) % self.page_size == prefix_lens % self.page_size
         )

+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (prefix_lens + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(prefix_lens)
         out_indices = torch.empty(
             (extend_num_tokens,), dtype=torch.int32, device=self.device
@@ -668,6 +716,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
             (last_loc + 2) % self.page_size == seq_lens % self.page_size
         )

+        estimated_num_new_pages = (
+            (
+                (seq_lens + self.page_size - 1) // self.page_size
+                - (seq_lens - 1 + self.page_size - 1) // self.page_size
+            )
+            .sum()
+            .item()
+        )
+        if estimated_num_new_pages > len(self.free_pages):
+            self.merge_and_sort_free()
+
         bs = len(seq_lens)
         out_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)

@@ -692,3 +751,4 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
     def clear(self):
         super().clear()
         self.free_pages = self.free_pages.to(torch.int32)
+        self.release_pages = self.release_pages.to(torch.int32)
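The allocator changes introduce a second list, release_pages: frees append to it without sorting, and merge_and_sort_free() folds it back into free_pages only when an allocation (or the estimated page count for an extend/decode) would otherwise come up short. A toy stand-alone version of that free/release split, assuming a flat page allocator rather than the sglang classes themselves:

    import torch

    class LazyFreeList:
        """Toy page free-list: cheap frees into release_pages, sort only on demand."""

        def __init__(self, num_pages: int, device: str = "cpu"):
            self.device = device
            self.free_pages = torch.arange(num_pages, dtype=torch.int64, device=device)
            self.release_pages = torch.empty((0,), dtype=torch.int64, device=device)

        def merge_and_sort_free(self):
            # fold released pages back in and restore sorted order
            if len(self.release_pages) > 0:
                self.free_pages = torch.cat((self.free_pages, self.release_pages))
                self.free_pages, _ = torch.sort(self.free_pages)
                self.release_pages = torch.empty((0,), dtype=torch.int64, device=self.device)

        def alloc(self, need_size: int):
            if need_size > len(self.free_pages):
                self.merge_and_sort_free()  # lazy consolidation, only when needed
            if need_size > len(self.free_pages):
                return None                 # still out of pages
            out = self.free_pages[:need_size]
            self.free_pages = self.free_pages[need_size:]
            return out

        def free(self, page_indices: torch.Tensor):
            # no sort on the hot path; just append to the release list
            self.release_pages = torch.cat((self.release_pages, page_indices))

Deferring the sort amortizes the torch.sort cost across many free() calls instead of paying it on every one, while the pre-allocation estimates keep sorted order available whenever a large request actually needs the released pages.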