sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -169,12 +169,13 @@ class StorageOperation:
         host_indices: torch.Tensor,
         token_ids: List[int],
         last_hash: Optional[str] = None,
+        hash_value: Optional[List[str]] = None,
     ):
         self.host_indices = host_indices
         self.token_ids = token_ids
         self.last_hash = last_hash
         self.completed_tokens = 0
-        self.hash_value = []
+        self.hash_value = hash_value if hash_value is not None else []

         self.id = StorageOperation.counter
         StorageOperation.counter += 1
@@ -259,6 +260,7 @@ class HiCacheController:
             self.storage_backend = MooncakeStore()
             self.get_hash_str = get_hash_str_mooncake
             self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+            assert self.mem_pool_host.layout == "page_first"
         elif storage_backend == "hf3fs":
             from sglang.srt.distributed import get_tensor_model_parallel_rank
             from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
@@ -433,7 +435,9 @@ class HiCacheController:
         if self.io_backend == "kernel":
             return host_indices.to(self.mem_pool_device.device), device_indices
         elif self.io_backend == "direct":
-
+            device_indices = device_indices.cpu()
+            host_indices, idx = host_indices.sort()
+            return host_indices, device_indices.index_select(0, idx)
         else:
             raise ValueError(f"Unsupported io backend")

@@ -570,10 +574,6 @@ class HiCacheController:
                 )
                 completed_tokens += self.page_size
             else:
-                # operation terminated by controller, release pre-allocated memory
-                self.mem_pool_host.free(
-                    operation.host_indices[operation.completed_tokens :]
-                )
                 break

     def mooncake_page_transfer(self, operation):
@@ -599,6 +599,14 @@ class HiCacheController:
                     self.generic_page_transfer(operation, batch_size=128)
                 else:
                     self.generic_page_transfer(operation)
+
+                if self.tp_world_size > 1:
+                    # to ensure all TP workers release the host memory at the same time
+                    torch.distributed.barrier(group=self.prefetch_tp_group)
+                # operation terminated by controller, release pre-allocated memory
+                self.mem_pool_host.free(
+                    operation.host_indices[operation.completed_tokens :]
+                )
             except Empty:
                 continue

@@ -626,7 +634,9 @@ class HiCacheController:
                 continue

             storage_hit_count = 0
-            if
+            if (
+                operation.host_indices is not None
+            ) and self.prefetch_rate_limit_check():
                 last_hash = operation.last_hash
                 tokens_to_fetch = operation.token_ids

@@ -670,7 +680,8 @@ class HiCacheController:
             if storage_hit_count < self.prefetch_threshold:
                 # not to prefetch if not enough benefits
                 self.prefetch_revoke_queue.put(operation.request_id)
-
+                if operation.host_indices is not None:
+                    self.mem_pool_host.free(operation.host_indices)
                 logger.debug(
                     f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                 )
@@ -693,12 +704,12 @@ class HiCacheController:
         self,
         host_indices: torch.Tensor,
         token_ids: List[int],
-
+        hash_value: Optional[List[str]] = None,
     ) -> int:
         """
         Write KV caches from host memory to storage backend.
         """
-        operation = StorageOperation(host_indices, token_ids,
+        operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
         self.backup_queue.put(operation)
         return operation.id

@@ -753,24 +764,6 @@ class HiCacheController:
             if operation is None:
                 continue

-            last_hash = operation.last_hash
-            tokens_to_backup = operation.token_ids
-
-            backup_hit_count = 0
-            remaining_tokens = len(tokens_to_backup)
-            hash_value = []
-            while remaining_tokens >= self.page_size:
-                last_hash = self.get_hash_str(
-                    tokens_to_backup[
-                        backup_hit_count : backup_hit_count + self.page_size
-                    ],
-                    last_hash,
-                )
-                backup_hit_count += self.page_size
-                hash_value.append(last_hash)
-                remaining_tokens -= self.page_size
-            operation.hash_value = hash_value
-
             if self.is_mooncake_backend():
                 self.mooncake_page_backup(operation)
             elif self.storage_backend_type == "hf3fs":
@@ -793,7 +786,6 @@ class HiCacheController:
                 self.ack_backup_queue.put(
                     (
                         operation.id,
-                        operation.hash_value[: min_completed_tokens // self.page_size],
                         min_completed_tokens,
                     )
                 )
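With the `hash_value` parameter added above, per-page hashes are computed before a `StorageOperation` is enqueued rather than inside the backup thread (the loop removed in the `-753,24` hunk). A minimal sketch of that chained per-page hashing, with `get_hash_str` standing in for the controller's configured hash function and the enqueueing call shown only as a hypothetical comment:

```python
from typing import Callable, List, Optional


def compute_page_hashes(
    token_ids: List[int],
    page_size: int,
    get_hash_str: Callable[[List[int], Optional[str]], str],
    last_hash: Optional[str] = None,
) -> List[str]:
    """Chain-hash every full page of token ids, mirroring the loop removed from the backup thread."""
    hash_value: List[str] = []
    # Only full pages are hashed; a trailing partial page is ignored, as in the removed loop.
    for start in range(0, len(token_ids) - len(token_ids) % page_size, page_size):
        last_hash = get_hash_str(token_ids[start : start + page_size], last_hash)
        hash_value.append(last_hash)
    return hash_value


# Hypothetical call site: the precomputed hashes now ride along with the StorageOperation, e.g.
# controller.write(host_indices, token_ids, hash_value=compute_page_hashes(...))
```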
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -216,7 +216,7 @@ class DetokenizerManager:
             rids=recv_obj.rids,
             finished_reasons=recv_obj.finished_reasons,
             output_strs=output_strs,
-            output_ids=recv_obj.
+            output_ids=recv_obj.output_ids,
             prompt_tokens=recv_obj.prompt_tokens,
             completion_tokens=recv_obj.completion_tokens,
             cached_tokens=recv_obj.cached_tokens,
sglang/srt/managers/io_struct.py
CHANGED
@@ -99,25 +99,24 @@ class GenerateReqInput:
     stream: bool = False
     # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
     log_metrics: bool = True
+    # Whether to return hidden states
+    return_hidden_states: Union[List[bool], bool] = False

     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # Session info for continual prompting
+    session_params: Optional[Union[List[Dict], Dict]] = None
+
     # The path to the LoRA adaptors
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
     # The uid of LoRA adaptors, should be initialized by tokenizer manager
     lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None

-    # Session info for continual prompting
-    session_params: Optional[Union[List[Dict], Dict]] = None
-
     # Custom logit processor for advanced sampling control. Must be a serialized instance
     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
     # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None

-    # Whether to return hidden states
-    return_hidden_states: Union[List[bool], bool] = False
-
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
@@ -456,6 +455,7 @@ class GenerateReqInput:
             log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
             lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            lora_id=self.lora_id[i] if self.lora_id is not None else None,
             custom_logit_processor=(
                 self.custom_logit_processor[i]
                 if self.custom_logit_processor is not None
sglang/srt/managers/mm_utils.py
CHANGED
@@ -614,8 +614,7 @@ def general_mm_embed_routine(
         input_ids: Input token IDs tensor
         forward_batch: Batch information for model forward pass
         language_model: Base language model to use
-
-        audio_data_embedding_func: Function to embed audio data
+        data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
         placeholder_tokens: Token IDs for multimodal placeholders
         **kwargs: Additional arguments passed to language model

sglang/srt/managers/multimodal_processor.py
CHANGED
@@ -20,7 +20,7 @@ def import_processors():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}:
+            logger.warning(f"Ignore import error when loading {name}: {e}")
             continue
         all_members = inspect.getmembers(module, inspect.isclass)
         classes = [
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -37,6 +37,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
+from itertools import chain
 from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

 import numpy as np
@@ -57,6 +58,7 @@ from sglang.srt.mem_cache.allocator import (
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import TimeStats
@@ -638,14 +640,26 @@ class Req:
         ):
             self.fill_ids = self.origin_input_ids + self.output_ids
             if tree_cache is not None:
-                (
-
-
-
-
-
-
-
+                if isinstance(tree_cache, LoRARadixCache):
+                    (
+                        self.prefix_indices,
+                        self.last_node,
+                        self.last_host_node,
+                        self.host_hit_length,
+                    ) = tree_cache.match_prefix_with_lora_id(
+                        key=LoRAKey(
+                            lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
+                        ),
+                    )
+                else:
+                    (
+                        self.prefix_indices,
+                        self.last_node,
+                        self.last_host_node,
+                        self.host_hit_length,
+                    ) = tree_cache.match_prefix(
+                        key=self.adjust_max_prefix_ids(),
+                    )
             self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):
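The branch above keys prefix matching on `LoRAKey(lora_id, token_ids)` when the cache is a `LoRARadixCache`, so identical prompts served by different adapters do not reuse each other's cached KV. A minimal illustration of why the key must include the adapter id, using a plain dict as a stand-in for the radix tree (the real cache does longest-prefix matching, not exact-key lookup):

```python
from typing import Dict, List, Optional, Tuple

# Toy prefix store keyed by (lora_id, token prefix); a stand-in for the radix tree.
_store: Dict[Tuple[Optional[str], Tuple[int, ...]], str] = {}


def insert(lora_id: Optional[str], tokens: List[int], kv_handle: str) -> None:
    _store[(lora_id, tuple(tokens))] = kv_handle


def lookup(lora_id: Optional[str], tokens: List[int]) -> Optional[str]:
    # The same token ids under a different adapter must not hit the cached KV.
    return _store.get((lora_id, tuple(tokens)))


insert("adapter-A", [1, 2, 3], "kv#A")
assert lookup("adapter-A", [1, 2, 3]) == "kv#A"
assert lookup("adapter-B", [1, 2, 3]) is None
```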
@@ -1145,9 +1159,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        input_ids_tensor = torch.tensor(
-
-        )
+        input_ids_tensor = torch.tensor(
+            list(chain.from_iterable(input_ids)), dtype=torch.int64
+        ).to(self.device, non_blocking=True)
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
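The rebuilt `input_ids_tensor` flattens the per-request token lists with `itertools.chain.from_iterable` before constructing one int64 tensor. The same idiom in isolation:

```python
from itertools import chain

import torch

input_ids = [[101, 7592], [101, 2088, 102]]  # two requests of different lengths
flat = torch.tensor(list(chain.from_iterable(input_ids)), dtype=torch.int64)
print(flat.shape)  # a single 1-D int64 tensor of length 5
```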
@@ -1713,15 +1727,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-            attention_backend_str
-
-
-
-
-
-
-
-
+            attention_backend_str
+            in [
+                "fa3",
+                "flashinfer",
+                "flashmla",
+                "cutlass_mla",
+                "ascend",
+                "trtllm_mha",
+                "aiter",
+            ]
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-
+CLIP_MAX_NEW_TOKENS = int(
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )

@@ -305,7 +305,7 @@ class PrefillAdder:
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
-
+                    CLIP_MAX_NEW_TOKENS,
                 )
                 * self.new_token_ratio
                 for r in running_batch.reqs
@@ -388,7 +388,7 @@ class PrefillAdder:
                 0,
                 req.extend_input_len,
                 (
-                    min(req.sampling_params.max_new_tokens,
+                    min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
                     if not truncated
                     else 0
                 ),
@@ -477,7 +477,7 @@ class PrefillAdder:
             self._update_prefill_budget(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
             )
         else:
             if self.rem_chunk_tokens == 0:
@@ -499,7 +499,7 @@ class PrefillAdder:
             return self.add_one_req_ignore_eos(req, has_chunked_req)

         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens,
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
         )

         # adjusting the input_tokens based on host_hit_length and page_size
@@ -544,7 +544,7 @@ class PrefillAdder:
             input_tokens,
             min(
                 req.sampling_params.max_new_tokens,
-
+                CLIP_MAX_NEW_TOKENS,
             ),
         )
     else:
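Every `PrefillAdder` site above clamps its `max_new_tokens` estimate with the module-level `CLIP_MAX_NEW_TOKENS`, read from `SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION` (default 4096). A condensed sketch of that clamp on its own; `estimated_new_tokens` is an illustrative helper, not a function from the diff:

```python
import os

# Clips only the scheduler's new-token estimate; the request's real stop condition
# still uses the unclipped max_new_tokens (see the preserved comment in the hunk above).
CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096"))


def estimated_new_tokens(max_new_tokens: int, output_len: int, new_token_ratio: float = 1.0) -> float:
    # Mirrors min(max_new_tokens - len(output_ids), CLIP_MAX_NEW_TOKENS) * new_token_ratio
    return min(max_new_tokens - output_len, CLIP_MAX_NEW_TOKENS) * new_token_ratio


print(estimated_new_tokens(max_new_tokens=32768, output_len=100))  # 4096.0 with the default clip
```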
sglang/srt/managers/scheduler.py
CHANGED
@@ -130,6 +130,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
+from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
@@ -611,12 +612,7 @@ class Scheduler(
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
-                hicache_io_backend=
-                    "direct"
-                    if server_args.attention_backend
-                    == "fa3"  # hot fix for incompatibility
-                    else server_args.hicache_io_backend
-                ),
+                hicache_io_backend=server_args.hicache_io_backend,
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
@@ -635,7 +631,19 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
-
+        elif self.enable_lora:
+            assert (
+                not self.enable_hierarchical_cache
+            ), "LoRA radix cache doesn't support hierarchical cache"
+            assert (
+                self.schedule_policy == "fcfs"
+            ), "LoRA radix cache only supports FCFS policy"
+            self.tree_cache = LoRARadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+            )
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
sglang/srt/managers/scheduler_profiler_mixin.py
CHANGED
@@ -8,6 +8,18 @@ import torch

 from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu
+
+    patches = [
+        ["profiler.profile", torch_npu.profiler.profile],
+        ["profiler.ProfilerActivity.CUDA", torch_npu.profiler.ProfilerActivity.NPU],
+        ["profiler.ProfilerActivity.CPU", torch_npu.profiler.ProfilerActivity.CPU],
+    ]
+    torch_npu._apply_patches(patches)

 logger = logging.getLogger(__name__)

@@ -136,6 +148,13 @@ class SchedulerProfilerMixin:
             activities=torchprof_activities,
             with_stack=with_stack if with_stack is not None else True,
             record_shapes=record_shapes if record_shapes is not None else False,
+            on_trace_ready=(
+                None
+                if not _is_npu
+                else torch_npu.profiler.tensorboard_trace_handler(
+                    self.torch_profiler_output_dir
+                )
+            ),
         )
         self.torch_profiler.start()
         self.profile_in_progress = True
@@ -166,15 +185,16 @@ class SchedulerProfilerMixin:
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
-
-
-
-
-
-
-
+            if not _is_npu:
+                self.torch_profiler.export_chrome_trace(
+                    os.path.join(
+                        self.torch_profiler_output_dir,
+                        self.profile_id
+                        + f"-TP-{self.tp_rank}"
+                        + stage_suffix
+                        + ".trace.json.gz",
+                    )
                 )
-            )
         torch.distributed.barrier(self.tp_cpu_group)

         if self.rpd_profiler is not None:
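The mixin now chooses between two export styles: the non-NPU path still calls `export_chrome_trace` after `stop()`, while the NPU path hands export to an `on_trace_ready` handler at construction time. A minimal sketch of both styles using only the stock `torch.profiler` API; the `torch_npu` handler from the diff is not reproduced, and `use_trace_handler` is just a stand-in for the `_is_npu` switch:

```python
import os

import torch
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler

out_dir = "/tmp/prof"
os.makedirs(out_dir, exist_ok=True)
use_trace_handler = False  # stand-in for the _is_npu switch in the diff

prof = profile(
    activities=[ProfilerActivity.CPU],
    with_stack=True,
    # With a handler, the profiler writes traces itself when it stops.
    on_trace_ready=tensorboard_trace_handler(out_dir) if use_trace_handler else None,
)
prof.start()
torch.ones(8) @ torch.ones(8)  # a little work to record
prof.stop()

if not use_trace_handler:
    # Explicit export, mirroring the non-NPU branch in the hunk above.
    prof.export_chrome_trace(os.path.join(out_dir, "trace.json"))
```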
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -269,10 +269,9 @@ class TokenizerManager:
         self.asyncio_tasks = set()

         # Health check
-        self.
+        self.server_status = ServerStatus.Starting
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
-        self.server_status = ServerStatus.Starting

         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -291,8 +290,8 @@ class TokenizerManager:
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self.
-        self.
+        self.is_pause = False
+        self.is_pause_cond = asyncio.Condition()

         # LoRA
         # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
@@ -476,16 +475,20 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

-        async with self._is_updating_cond:
-            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
                 f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )

+        async with self.is_pause_cond:
+            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
         async with self.model_update_lock.reader_lock:
+            if self.server_args.enable_lora and obj.lora_path:
+                # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
+                obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
+
             if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
@@ -553,11 +556,6 @@ class TokenizerManager:
         else:
             mm_inputs = None

-        if self.server_args.enable_lora and obj.lora_path:
-            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
-            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
-
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -775,10 +773,6 @@ class TokenizerManager:
             msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
             logger.info(msg)

-        # Mark ongoing LoRA request as finished.
-        if self.server_args.enable_lora and obj.lora_path:
-            await self.lora_registry.release(obj.lora_id)
-
         # Check if this was an abort/error created by scheduler
         if isinstance(out["meta_info"].get("finish_reason"), dict):
             finish_reason = out["meta_info"]["finish_reason"]
@@ -797,6 +791,11 @@ class TokenizerManager:
                 # Delete the key to prevent resending abort request to the scheduler and
                 # to ensure aborted request state is cleaned up.
                 del self.rid_to_state[state.obj.rid]
+
+                # Mark ongoing LoRA request as finished.
+                if self.server_args.enable_lora and state.obj.lora_path:
+                    await self.lora_registry.release(state.obj.lora_id)
+
                 raise fastapi.HTTPException(
                     status_code=finish_reason["status_code"],
                     detail=finish_reason["message"],
@@ -982,14 +981,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

     async def pause_generation(self):
-        async with self.
-            self.
+        async with self.is_pause_cond:
+            self.is_pause = True
             self.abort_request(abort_all=True)

     async def continue_generation(self):
-        async with self.
-            self.
-            self.
+        async with self.is_pause_cond:
+            self.is_pause = False
+            self.is_pause_cond.notify_all()

     async def update_weights_from_disk(
         self,
@@ -1474,7 +1473,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)

-            if self.
+            if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
@@ -1600,6 +1599,10 @@ class TokenizerManager:
             meta_info["e2e_latency"] = state.finished_time - state.created_time
             del self.rid_to_state[rid]

+            # Mark ongoing LoRA request as finished.
+            if self.server_args.enable_lora and state.obj.lora_path:
+                asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
+
             state.out_list.append(out_dict)
             state.event.set()

@@ -1965,10 +1968,6 @@ class ServerStatus(Enum):
     Up = "Up"
     Starting = "Starting"
     UnHealthy = "UnHealthy"
-    Crashed = "Crashed"
-
-    def is_healthy(self) -> bool:
-        return self == ServerStatus.Up


 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode: