sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/lora/mem_pool.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, List, Optional, Set, Tuple
 
 import torch
 
@@ -22,21 +22,16 @@ class LoRAMemoryPool:
         self,
         base_hf_config: AutoConfig,
         max_loras_per_batch: int,
-        max_lora_dim: int,
         dtype: torch.dtype,
         tp_size: int,
         tp_rank: int,
-        lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
     ):
-
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers
         self.max_loras_per_batch: int = max_loras_per_batch
-        self.max_lora_dim: int = max_lora_dim
         self.dtype: torch.dtype = dtype
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
-        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -55,89 +50,95 @@ class LoRAMemoryPool:
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
 
     def get_lora_A_shape(
-        self, module_name: str, base_model: torch.nn.Module
+        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
-        if self.tp_size > 1:
-            if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
-                input_dim = divide(input_dim, self.tp_size)
+        if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+            input_dim = divide(input_dim, self.tp_size)
         return (
             self.max_loras_per_batch,
-            self.max_lora_dim * c,
+            max_lora_dim * c,
             input_dim,
         )
 
     def get_lora_B_shape(
-        self, module_name: str, base_model: torch.nn.Module
+        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
-        if self.tp_size > 1:
-            if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
-                output_dim = divide(output_dim, self.tp_size)
+        if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
+            output_dim = divide(output_dim, self.tp_size)
         return (
             c,
             self.max_loras_per_batch,
             output_dim,
-            self.max_lora_dim,
+            max_lora_dim,
         )
 
     def init_buffers(
         self,
         lora_weight_names: Tuple[Set[str]],
         base_model: torch.nn.Module,
+        max_lora_dim: int,
     ):
-
         # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
         self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
         device = next(base_model.parameters()).device
-        # [22 removed lines: the old one-shot A_buffer/B_buffer allocation; content lost in this extract]
+
+        def update_buffer(
+            buffer: Dict[str, List[torch.Tensor]],
+            lora_weight_names: Set[str],
+            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+        ):
+            new_weight_names = lora_weight_names - buffer.keys()
+            for module_name in new_weight_names:
+                lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
+                buffer[module_name] = [
+                    torch.empty(
+                        lora_shape,
+                        dtype=self.dtype,
+                        device=device,
+                    )
+                    for _ in range(self.num_layer)
+                ]
+
+        update_buffer(
+            self.A_buffer,
+            lora_weight_names[0],
+            self.get_lora_A_shape,
+        )
+
+        update_buffer(
+            self.B_buffer,
+            lora_weight_names[1],
+            self.get_lora_B_shape,
+        )
 
     def prepare_lora_batch(
         self,
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
+        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
-
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == "":
-                    return buffer_id, ""
+                    return buffer_id
 
             for buffer_id in range(self.max_loras_per_batch):
                 # Evict unneeded lora
                 if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-                    return buffer_id, self.buffer_id_to_uid[buffer_id]
+                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
+                    return buffer_id
 
             raise ValueError(
                 "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
@@ -145,17 +146,20 @@ class LoRAMemoryPool:
 
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
-                buffer_id, evicted_lora_uid = get_available_buffer_slot()
-                if evicted_lora_uid != "":
-                    self.uid_to_buffer_id.pop(evicted_lora_uid)
+                buffer_id = get_available_buffer_slot()
+                lora_adapter = lora_adapters.get(uid, None)
                 self.load_lora_weight_to_buffer(
-                    uid, buffer_id, lora_adapters.get(uid, None)
+                    uid, buffer_id, lora_adapter, lora_modules
                 )
                 self.uid_to_buffer_id[uid] = buffer_id
                 self.buffer_id_to_uid[buffer_id] = uid
 
     def load_lora_weight_to_buffer(
-        self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
+        self,
+        uid: str,
+        buffer_id: int,
+        lora_adapter: LoRAAdapter,
+        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
         def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
             assert (
@@ -187,8 +191,8 @@ class LoRAMemoryPool:
                 temp_B_buffer[lora_weight_name] = weights
 
         if self.tp_size > 1:
-            cur_layer_modules = self.lora_modules[layer_id]
-            for module_name, module in cur_layer_modules:
+            cur_layer_modules = lora_modules[layer_id]
+            for module_name, module in cur_layer_modules.items():
                 if "qkv_proj" in module_name:
                     temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                         temp_A_buffer["qkv_proj"], self.tp_rank
@@ -237,7 +241,6 @@ class LoRAMemoryPool:
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
     ) -> torch.Tensor:
-
         if lora_type == LoRAType.LORA_A:
             return self.A_buffer[weight_name][layer_id]
 
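The net effect of this change is that buffer allocation becomes incremental: `init_buffers` can be re-run when newly loaded adapters introduce weight names the pool has not seen, and only the missing buffers are created. Passing `max_lora_dim` into the shape helpers instead of fixing it at construction presumably supports re-initializing when a later adapter needs a different rank. A minimal sketch of the allocate-only-what's-new pattern (the shape function, dtype, and names below are illustrative, not the real pool's):

```python
import torch

NUM_LAYERS = 2

def update_buffer(buffer, weight_names, get_shape_fn):
    # Allocate per-layer tensors only for names not seen before,
    # mirroring `lora_weight_names - buffer.keys()` in the diff above.
    for name in weight_names - buffer.keys():
        buffer[name] = [
            torch.empty(get_shape_fn(name), dtype=torch.float16)
            for _ in range(NUM_LAYERS)
        ]

A_buffer = {}
update_buffer(A_buffer, {"qkv_proj", "o_proj"}, lambda name: (4, 16, 32))
# A later call with an overlapping set allocates only the new name.
update_buffer(A_buffer, {"qkv_proj", "gate_up_proj"}, lambda name: (4, 16, 32))
print(sorted(A_buffer))  # ['gate_up_proj', 'o_proj', 'qkv_proj']
```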
sglang/srt/lora/utils.py
CHANGED
@@ -108,7 +108,7 @@ def get_hidden_dim(
 
 
 def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
     """
-    Mapping a target module name to names of the
+    Mapping a target module name to names of the normalized LoRA weights.
     Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
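For context, the pairs this function returns follow the stacked-module convention quoted in the `mem_pool.py` comment above: `("qkv_proj", "q_proj")`, `("qkv_proj", "kv_proj")`, `("o_proj", "o_proj")`. A hypothetical sketch of the lookup, with an illustrative subset standing in for the real `params_mapping` table:

```python
# Hypothetical subset of the normalization table; q/k/v projections share
# one stacked "qkv_proj" LoRA A buffer, per the example pairs above.
params_mapping = {
    "q_proj": (["qkv_proj"], ["q_proj"]),
    "k_proj": (["qkv_proj"], ["kv_proj"]),
    "v_proj": (["qkv_proj"], ["kv_proj"]),
}

def get_normalized_lora_weight_names(name):
    # Unstacked modules map to their own name for both LoRA A and B.
    return params_mapping.get(name, ([name], [name]))

print(get_normalized_lora_weight_names("q_proj"))  # (['qkv_proj'], ['q_proj'])
print(get_normalized_lora_weight_names("o_proj"))  # (['o_proj'], ['o_proj'])
```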
sglang/srt/managers/cache_controller.py
CHANGED
@@ -18,33 +18,50 @@ import logging
 import math
 import threading
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
-
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+    from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
 logger = logging.getLogger(__name__)
 
 
 class LayerDoneCounter:
     def __init__(self, num_layers):
-        self.counter = num_layers
-        self.condition = threading.Condition()
+        self.num_layers = num_layers
+        # extra producer and consumer counters for overlap mode
+        self.num_counters = 3
+        self.counters = [num_layers] * self.num_counters
+        self.conditions = [threading.Condition() for _ in range(self.num_counters)]
+        self.producer_index = 0
+        self.consumer_index = 0
+
+    def next_producer(self):
+        return (self.producer_index + 1) % self.num_counters
+
+    def update_producer(self):
+        self.producer_index = self.next_producer()
+        return self.producer_index
+
+    def set_consumer(self, index):
+        self.consumer_index = index
 
     def increment(self):
-        with self.condition:
-            self.counter += 1
-            self.condition.notify_all()
+        with self.conditions[self.producer_index]:
+            self.counters[self.producer_index] += 1
+            self.conditions[self.producer_index].notify_all()
 
     def wait_until(self, threshold):
-        with self.condition:
-            while self.counter <= threshold:
-                self.condition.wait()
+        with self.conditions[self.consumer_index]:
+            while self.counters[self.consumer_index] <= threshold:
+                self.conditions[self.consumer_index].wait()
 
     def reset(self):
-        with self.condition:
-            self.counter = 0
+        with self.conditions[self.producer_index]:
+            self.counters[self.producer_index] = 0
 
 
 class CacheOperation:
@@ -147,7 +164,7 @@ class HiCacheController:
 
     def __init__(
         self,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
         load_cache_event: threading.Event = None,
@@ -295,7 +312,6 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.load_queue.get(block=True, timeout=1)
-                # time.sleep(18e-6 * len(operation.host_indices))
                 operation.data = self.mem_pool_host.get_flat_data(
                     operation.host_indices
                 )
@@ -319,6 +335,7 @@ class HiCacheController:
             if not self.load_cache_event.is_set():
                 continue
             self.load_cache_event.clear()
+            self.layer_done_counter.update_producer()
 
             batch_operation = None
             while self.load_queue.qsize() > 0:
@@ -330,6 +347,7 @@ class HiCacheController:
             if batch_operation is None:
                 continue
 
+            # start layer-wise KV cache transfer from CPU to GPU
             self.layer_done_counter.reset()
             for i in range(self.mem_pool_host.layer_num):
                 if self.page_size == 1:
@@ -465,6 +483,7 @@ class HiCacheController:
         except Exception as e:
             logger.error(e)
 
+    # todo (zhiqiang): double buffering to be deprecated
    def write_thread_func_buffer(self):
        aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
        aux_thread.start()
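The `LayerDoneCounter` rework swaps one shared condition-variable counter for a small ring of producer/consumer counters so that, in overlap mode, a new transfer batch can publish progress on the next counter while the previous batch's consumers still wait on the old one. The underlying handshake is unchanged: a loader thread marks layers done one at a time, and attention layers block until their layer's KV cache has arrived. A single-counter sketch of that handshake (function names and the sleep are illustrative):

```python
import threading
import time

condition = threading.Condition()
layers_done = 0

def load_layers(num_layers):
    # Producer: finish one layer's host->device copy, then notify waiters.
    global layers_done
    for _ in range(num_layers):
        time.sleep(0.01)  # stand-in for copying one layer of KV cache
        with condition:
            layers_done += 1
            condition.notify_all()

def wait_until(layer_id):
    # Consumer: block until this layer's KV cache has been loaded.
    with condition:
        while layers_done <= layer_id:
            condition.wait()

t = threading.Thread(target=load_layers, args=(4,))
t.start()
for layer_id in range(4):
    wait_until(layer_id)
t.join()
print("all 4 layers loaded and consumed")
```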
sglang/srt/managers/io_struct.py
CHANGED
@@ -87,7 +87,7 @@ class GenerateReqInput:
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    # LoRA
+    # The path to the LoRA
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     # Session info for continual prompting
@@ -99,7 +99,7 @@ class GenerateReqInput:
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
     # Whether to return hidden states
-    return_hidden_states: bool = False
+    return_hidden_states: Union[List[bool], bool] = False
 
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
@@ -226,11 +226,11 @@ class GenerateReqInput:
 
         # Expand input based on type
         self._expand_inputs(num)
+        self._normalize_rid(num)
         self._normalize_lora_paths(num)
         self._normalize_image_data(num)
         self._normalize_audio_data(num)
         self._normalize_sampling_params(num)
-        self._normalize_rid(num)
         self._normalize_logprob_params(num)
         self._normalize_custom_logit_processor(num)
 
@@ -409,7 +409,11 @@ class GenerateReqInput:
                 if self.custom_logit_processor is not None
                 else None
             ),
-            return_hidden_states=self.return_hidden_states,
+            return_hidden_states=(
+                self.return_hidden_states[i]
+                if isinstance(self.return_hidden_states, list)
+                else self.return_hidden_states
+            ),
             # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
@@ -477,7 +481,7 @@ class TokenizedGenerateReqInput:
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
-    text: Optional[Union[List[str], str]] = None
+    text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
     # Can be formatted as:
     # - Single image for a single request
@@ -501,6 +505,8 @@ class EmbeddingReqInput:
     log_metrics: bool = True
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # For cross-encoder requests
+    is_cross_encoder_request: bool = False
 
     def contains_mm_input(self) -> bool:
         return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
@@ -524,6 +530,7 @@ class EmbeddingReqInput:
         if self.text is not None:
             if isinstance(self.text, list):
                 self.batch_size += len(self.text)
+                self.is_single = False
             else:
                 self.batch_size += 1
 
@@ -531,12 +538,10 @@ class EmbeddingReqInput:
         if self.input_ids is not None:
             if isinstance(self.input_ids[0], list):
                 self.batch_size += len(self.input_ids)
+                self.is_single = False
             else:
                 self.batch_size += 1
 
-        if self.batch_size > 1:
-            self.is_single = False
-
         # Fill in default arguments
         if self.is_single:
             if self.rid is None:
@@ -560,6 +565,16 @@ class EmbeddingReqInput:
         return self.rid
 
     def __getitem__(self, i):
+        if self.is_cross_encoder_request:
+            return EmbeddingReqInput(
+                text=[self.text[i]] if self.text is not None else None,
+                input_ids=None,
+                image_data=None,
+                sampling_params=self.sampling_params[i],
+                rid=self.rid[i],
+                is_cross_encoder_request=True,
+            )
+
         return EmbeddingReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
@@ -579,6 +594,8 @@ class TokenizedEmbeddingReqInput:
     input_ids: List[int]
     # The image inputs
     image_inputs: dict
+    # The token type ids
+    token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
 
@@ -794,7 +811,9 @@ class GetWeightsByNameReqOutput:
 
 @dataclass
 class ReleaseMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None
 
 
 @dataclass
@@ -804,7 +823,9 @@ class ReleaseMemoryOccupationReqOutput:
 
 @dataclass
 class ResumeMemoryOccupationReqInput:
-    pass
+    # Optional tags to identify the memory region, which is primarily used for RL
+    # Currently we only support `weights` and `kv_cache`
+    tags: Optional[List[str]] = None
 
 
 @dataclass
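Among these changes, `return_hidden_states` now accepts either a single flag for the whole batch or a per-request list, with `__getitem__` picking the right element when a batch is split into sub-requests. A toy version of that broadcast-or-index pattern (the dataclass below is a stand-in, not the real `GenerateReqInput`):

```python
from dataclasses import dataclass
from typing import List, Union

@dataclass
class ToyReq:
    return_hidden_states: Union[List[bool], bool] = False

    def __getitem__(self, i):
        # Scalar flags broadcast to every sub-request; lists are indexed.
        return ToyReq(
            return_hidden_states=(
                self.return_hidden_states[i]
                if isinstance(self.return_hidden_states, list)
                else self.return_hidden_states
            )
        )

batch = ToyReq(return_hidden_states=[True, False])
print(batch[0].return_hidden_states, batch[1].return_hidden_states)  # True False
print(ToyReq(return_hidden_states=True)[1].return_hidden_states)     # True
```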
sglang/srt/managers/multimodal_processors/base_processor.py
CHANGED
@@ -146,7 +146,7 @@ class BaseMultimodalProcessor(ABC):
         request_obj,
         max_req_input_len,
         **kwargs,
-    ):
+    ) -> Optional[Dict[str, Any]]:
         pass
 
     def get_estimated_frames_list(self, image_data):
@@ -261,7 +261,7 @@ class BaseMultimodalProcessor(ABC):
 
     def load_mm_data(
         self,
-        prompt: str,
+        prompt: str | List[int],
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
sglang/srt/managers/multimodal_processors/vila.py
ADDED
@@ -0,0 +1,85 @@
+from typing import Any, Dict, List, Optional, Type, cast
+
+import torch.nn as nn
+from transformers.configuration_utils import PretrainedConfig
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    ImageDataItem,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.vila import VILAForConditionalGeneration
+from sglang.srt.server_args import ServerArgs
+
+
+class VILAProcessor(ProcessorMixin):
+    """A stub class for the VILA processor."""
+
+    tokenizer: PreTrainedTokenizerBase
+
+
+class VILAMultimodalProcessor(BaseMultimodalProcessor):
+    models: List[Type[nn.Module]] = [VILAForConditionalGeneration]
+
+    _processor: VILAProcessor
+
+    def __init__(
+        self,
+        hf_config: PretrainedConfig,
+        server_args: ServerArgs,
+        _processor: VILAProcessor,
+    ) -> None:
+        super().__init__(hf_config, server_args, _processor)
+
+    async def process_mm_data_async(
+        self,
+        image_data: Optional[ImageDataItem | List[ImageDataItem]],
+        input_text: str | List[int],
+        request_obj: GenerateReqInput | EmbeddingReqInput,
+        max_req_input_len: int,
+        **kwargs,
+    ) -> Optional[Dict[str, Any]]:
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        mm_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self._processor.tokenizer.image_token
+            ),
+            max_req_input_len=max_req_input_len,
+            image_data=image_data,
+        )
+
+        inputs = self.process_mm_data(
+            input_text=mm_data.input_text,
+            images=mm_data.images,
+        )
+
+        image_offsets = self.get_mm_items_offset(
+            input_ids=inputs.input_ids[0],
+            mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
+        )
+
+        mm_items: List[MultimodalDataItem] = [
+            MultimodalDataItem(
+                modality=Modality.IMAGE,
+                image_offsets=image_offsets,
+                pixel_values=inputs.pixel_values,
+            )
+        ]
+
+        return dict(
+            input_ids=inputs.input_ids[0].tolist(),
+            mm_items=mm_items,
+        )
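The new VILA processor relies on `get_mm_items_offset` to locate image placeholder tokens in the tokenized prompt so vision features can be spliced in at the right positions. A toy illustration of that lookup (the placeholder token id is made up, and the real helper's return format may differ, e.g. grouping positions into per-item ranges):

```python
import torch

def image_token_positions(input_ids: torch.Tensor, image_token_id: int):
    # Indices where the image placeholder token occurs in the prompt.
    return (input_ids == image_token_id).nonzero(as_tuple=True)[0].tolist()

# 32000 is an illustrative placeholder id, not VILA's actual one.
ids = torch.tensor([1, 42, 32000, 32000, 32000, 7, 2])
print(image_token_positions(ids, image_token_id=32000))  # [2, 3, 4]
```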