sglang-0.4.7-py3-none-any.whl → sglang-0.4.7.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/radix_attention.py
CHANGED
```diff
@@ -18,7 +18,6 @@ from typing import Optional
 
 from torch import nn
 
-from sglang.srt.layers.linear import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -52,9 +51,9 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
-        attn_type=AttentionType.DECODER,
-        prefix: str = "",
+        attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
```
sglang/srt/lora/lora_manager.py
CHANGED
```diff
@@ -81,7 +81,7 @@ class LoRAManager:
                 seg_indptr=torch.zeros(
                     self.max_bs_in_cuda_graph + 1, dtype=torch.int32
                 ),
-                max_len=
+                max_len=1,
                 weight_indices=torch.zeros(
                     self.max_bs_in_cuda_graph, dtype=torch.int32
                 ),
@@ -89,6 +89,17 @@ class LoRAManager:
                 scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
             )
 
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
+            dim=0,
+            out=self.cuda_graph_batch_info.seg_indptr[
+                1 : self.max_bs_in_cuda_graph + 1
+            ],
+        )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
@@ -159,6 +170,45 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
 
+        def transfer_adapter_info(
+            weight_indices_out: torch.Tensor,
+            lora_ranks_out: torch.Tensor,
+            scalings_out: torch.Tensor,
+        ):
+            """
+            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
+            to device (CUDA) asynchronously.
+            """
+            weight_indices = [0] * len(forward_batch.lora_paths)
+            lora_ranks = [0] * self.max_loras_per_batch
+            scalings = [0] * self.max_loras_per_batch
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+
+            # Use pinned memory to avoid synchronizations during host-to-device transfer
+            weight_indices_tensor = torch.tensor(
+                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            lora_ranks_tensor = torch.tensor(
+                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+            )
+            scalings_tensor = torch.tensor(
+                scalings, dtype=torch.float, pin_memory=True, device="cpu"
+            )
+
+            # Copy to device tensors asynchronously
+            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
+            lora_ranks_out[: self.max_loras_per_batch].copy_(
+                lora_ranks_tensor, non_blocking=True
+            )
+            scalings_out[: self.max_loras_per_batch].copy_(
+                scalings_tensor, non_blocking=True
+            )
+
         if (
             hasattr(self, "max_bs_in_cuda_graph")
             and bs <= self.max_bs_in_cuda_graph
@@ -166,51 +216,46 @@ class LoRAManager:
         ):
             # Do in-place updates when CUDA graph is enabled and the batch forward mode
             # could use CUDA graph.
-
-
-
-            self.cuda_graph_batch_info.
-
-                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+
+            transfer_adapter_info(
+                self.cuda_graph_batch_info.weight_indices,
+                self.cuda_graph_batch_info.lora_ranks,
+                self.cuda_graph_batch_info.scalings,
             )
-            self.cuda_graph_batch_info.max_len = 1
 
-
-
-                self.memory_pool.get_buffer_id(lora_path)
-            )
-            if lora_path is not None:
-                lora = self.loras[lora_path]
-                self.cuda_graph_batch_info.lora_ranks[
-                    self.cuda_graph_batch_info.weight_indices[i]
-                ] = lora.config.hf_config["r"]
-                self.cuda_graph_batch_info.scalings[
-                    self.cuda_graph_batch_info.weight_indices[i]
-                ] = lora.scaling
+            self.cuda_graph_batch_info.bs = bs
+            self.cuda_graph_batch_info.max_len = 1
             batch_info = self.cuda_graph_batch_info
         else:
+            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
+            lora_ranks = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+            )
+            scalings = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+            )
+            transfer_adapter_info(
+                weight_indices,
+                lora_ranks,
+                scalings,
+            )
+
             seg_lens = (
                 forward_batch.extend_seq_lens
                 if forward_batch.forward_mode.is_extend()
                 else torch.ones(bs, device=self.device)
            )
+
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+
             seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
             seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-            max_len = int(torch.max(seg_lens))
-            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
-            )
-            for i, lora_path in enumerate(forward_batch.lora_paths):
-                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-                if lora_path is not None:
-                    lora = self.loras[lora_path]
-                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
-                    scalings[weight_indices[i]] = lora.scaling
             batch_info = LoRABatchInfo(
                 bs=bs,
                 seg_lens=seg_lens,
```
sglang/srt/lora/mem_pool.py
CHANGED
```diff
@@ -132,12 +132,13 @@ class LoRAMemoryPool:
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == "":
-                    return buffer_id
+                    return buffer_id
 
             for buffer_id in range(self.max_loras_per_batch):
                 # Evict unneeded lora
                 if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-
+                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
+                    return buffer_id
 
             raise ValueError(
                 "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
@@ -145,9 +146,7 @@ class LoRAMemoryPool:
 
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
-                buffer_id
-                if evicted_lora_uid != "":
-                    self.uid_to_buffer_id.pop(evicted_lora_uid)
+                buffer_id = get_available_buffer_slot()
                 self.load_lora_weight_to_buffer(
                     uid, buffer_id, lora_adapters.get(uid, None)
                 )
```
sglang/srt/managers/cache_controller.py
CHANGED
```diff
@@ -22,7 +22,8 @@ from typing import List, Optional
 
 import torch
 
-from sglang.srt.mem_cache.memory_pool import
+from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool_host import HostKVCache
 
 logger = logging.getLogger(__name__)
 
```
sglang/srt/managers/io_struct.py
CHANGED
```diff
@@ -87,7 +87,7 @@ class GenerateReqInput:
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    # LoRA
+    # The path to the LoRA
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     # Session info for continual prompting
@@ -99,7 +99,7 @@ class GenerateReqInput:
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
     # Whether to return hidden states
-    return_hidden_states: bool = False
+    return_hidden_states: Union[List[bool], bool] = False
 
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
@@ -409,7 +409,11 @@ class GenerateReqInput:
                 if self.custom_logit_processor is not None
                 else None
             ),
-            return_hidden_states=
+            return_hidden_states=(
+                self.return_hidden_states[i]
+                if isinstance(self.return_hidden_states, list)
+                else self.return_hidden_states
+            ),
             # if `__getitem__` is called, the bootstrap_host, bootstrap_port, bootstrap_room must be a list
             bootstrap_host=(
                 self.bootstrap_host[i] if self.bootstrap_host is not None else None
@@ -477,7 +481,7 @@ class TokenizedGenerateReqInput:
 @dataclass
 class EmbeddingReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
-    text: Optional[Union[List[str], str]] = None
+    text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
     # Can be formatted as:
     # - Single image for a single request
@@ -501,6 +505,8 @@ class EmbeddingReqInput:
     log_metrics: bool = True
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
+    # For cross-encoder requests
+    is_cross_encoder_request: bool = False
 
     def contains_mm_input(self) -> bool:
         return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
@@ -560,6 +566,16 @@ class EmbeddingReqInput:
         return self.rid
 
     def __getitem__(self, i):
+        if self.is_cross_encoder_request:
+            return EmbeddingReqInput(
+                text=[self.text[i]] if self.text is not None else None,
+                input_ids=None,
+                image_data=None,
+                sampling_params=self.sampling_params[i],
+                rid=self.rid[i],
+                is_cross_encoder_request=True,
+            )
+
         return EmbeddingReqInput(
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
@@ -579,6 +595,8 @@ class TokenizedEmbeddingReqInput:
     input_ids: List[int]
     # The image inputs
     image_inputs: dict
+    # The token type ids
+    token_type_ids: List[int]
     # Dummy sampling params for compatibility
     sampling_params: SamplingParams
 
@@ -843,6 +861,12 @@ class SetInternalStateReq:
     server_args: Dict[str, Any]
 
 
+@dataclass
+class V1RerankReqInput:
+    query: str
+    documents: List[str]
+
+
 @dataclass
 class SetInternalStateReqOutput:
     updated: bool
```
sglang/srt/managers/multimodal_processors/base_processor.py
CHANGED
```diff
@@ -146,7 +146,7 @@ class BaseMultimodalProcessor(ABC):
         request_obj,
         max_req_input_len,
         **kwargs,
-    ):
+    ) -> Optional[Dict[str, Any]]:
         pass
 
     def get_estimated_frames_list(self, image_data):
@@ -261,7 +261,7 @@ class BaseMultimodalProcessor(ABC):
 
     def load_mm_data(
         self,
-        prompt: str,
+        prompt: str | List[int],
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
```
sglang/srt/managers/multimodal_processors/vila.py
ADDED
```diff
@@ -0,0 +1,85 @@
+from typing import Any, Dict, List, Optional, Type, cast
+
+import torch.nn as nn
+from transformers.configuration_utils import PretrainedConfig
+from transformers.processing_utils import ProcessorMixin
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    ImageDataItem,
+)
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    MultimodalSpecialTokens,
+)
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.vila import VILAForConditionalGeneration
+from sglang.srt.server_args import ServerArgs
+
+
+class VILAProcessor(ProcessorMixin):
+    """A stub class for the VILA processor."""
+
+    tokenizer: PreTrainedTokenizerBase
+
+
+class VILAMultimodalProcessor(BaseMultimodalProcessor):
+    models: List[Type[nn.Module]] = [VILAForConditionalGeneration]
+
+    _processor: VILAProcessor
+
+    def __init__(
+        self,
+        hf_config: PretrainedConfig,
+        server_args: ServerArgs,
+        _processor: VILAProcessor,
+    ) -> None:
+        super().__init__(hf_config, server_args, _processor)
+
+    async def process_mm_data_async(
+        self,
+        image_data: Optional[ImageDataItem | List[ImageDataItem]],
+        input_text: str | List[int],
+        request_obj: GenerateReqInput | EmbeddingReqInput,
+        max_req_input_len: int,
+        **kwargs,
+    ) -> Optional[Dict[str, Any]]:
+        if not image_data:
+            return None
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        mm_data = self.load_mm_data(
+            prompt=input_text,
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=self._processor.tokenizer.image_token
+            ),
+            max_req_input_len=max_req_input_len,
+            image_data=image_data,
+        )
+
+        inputs = self.process_mm_data(
+            input_text=mm_data.input_text,
+            images=mm_data.images,
+        )
+
+        image_offsets = self.get_mm_items_offset(
+            input_ids=inputs.input_ids[0],
+            mm_token_id=cast(int, self._processor.tokenizer.image_token_id),
+        )
+
+        mm_items: List[MultimodalDataItem] = [
+            MultimodalDataItem(
+                modality=Modality.IMAGE,
+                image_offsets=image_offsets,
+                pixel_values=inputs.pixel_values,
+            )
+        ]
+
+        return dict(
+            input_ids=inputs.input_ids[0].tolist(),
+            mm_items=mm_items,
+        )
```
sglang/srt/managers/schedule_batch.py
CHANGED
```diff
@@ -72,32 +72,33 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 GLOBAL_SERVER_ARGS_KEYS = [
     "attention_backend",
+    "mm_attention_backend",
     "debug_tensor_dump_inject",
     "debug_tensor_dump_output_folder",
     "chunked_prefill_size",
-    "deepep_mode",
     "device",
     "disable_chunked_prefix_cache",
     "disable_radix_cache",
-    "enable_deepep_moe",
     "enable_dp_attention",
     "enable_two_batch_overlap",
     "enable_dp_lm_head",
+    "enable_deepep_moe",
+    "deepep_mode",
     "enable_ep_moe",
+    "moe_dense_tp_size",
+    "ep_dispatch_algorithm",
     "deepep_config",
+    "ep_num_redundant_experts",
     "enable_nan_detection",
     "flashinfer_mla_disable_ragged",
     "max_micro_batch_size",
-    "moe_dense_tp_size",
-    "ep_dispatch_algorithm",
     "disable_shared_experts_fusion",
     "sampling_backend",
     "speculative_accept_threshold_acc",
     "speculative_accept_threshold_single",
     "torchao_config",
     "triton_attention_reduce_in_fp32",
-    "
-    "mm_attention_backend",
+    "num_reserved_decode_tokens",
 ]
 
 # Put some global args for easy access
@@ -444,6 +445,7 @@ class Req:
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
+        token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
         custom_logit_processor: Optional[str] = None,
         return_hidden_states: bool = False,
@@ -469,6 +471,9 @@ class Req:
         self.session_id = session_id
         self.input_embeds = input_embeds
 
+        # for cross-encoder model
+        self.token_type_ids = token_type_ids
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
@@ -840,6 +845,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Batched arguments to model runner
     input_ids: torch.Tensor = None  # shape: [b], int64
     input_embeds: torch.Tensor = None  # shape: [b, hidden_size], float32
+    token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
@@ -1141,6 +1147,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
 
+        token_type_ids = [
+            r.token_type_ids for r in reqs if r.token_type_ids is not None
+        ]
+
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
@@ -1153,6 +1163,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
+
+        token_type_ids_tensor = None
+        if len(token_type_ids) > 0:
+            token_type_ids_tensor = torch.tensor(
+                sum(token_type_ids, []), dtype=torch.int64
+            ).to(self.device, non_blocking=True)
+
         extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
         # Copy prefix and do some basic check
@@ -1268,6 +1285,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.device, non_blocking=True
         )
         self.multimodal_inputs = multimodal_inputs
+        self.token_type_ids = token_type_ids_tensor
         self.seq_lens_sum = sum(seq_lens)
 
         if self.return_logprob:
@@ -1414,6 +1432,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
+            if server_args.disaggregation_mode == "decode":
+                req.offload_kv_cache(
+                    self.req_to_token_pool, self.token_to_kv_pool_allocator
+                )
+
             if isinstance(self.tree_cache, ChunkCache):
                 # ChunkCache does not have eviction
                 token_indices = self.req_to_token_pool.req_to_token[
@@ -1445,6 +1468,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
             req.reset_for_retract()
 
+        if len(retracted_reqs) == 0:
+            # Corner case: only one request left
+            raise ValueError(
+                "Failed to retract any request. No space left for only one request."
+            )
+
         self.filter_batch(keep_indices=sorted_indices)
 
         # Reqs in batch are filtered
@@ -1702,6 +1731,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
+            token_type_ids=self.token_type_ids,
             spec_algorithm=self.spec_algorithm,
             spec_info=self.spec_info,
             capture_hidden_mode=(
@@ -1795,6 +1825,9 @@ class ModelWorkerBatch:
     # The input Embeds
     input_embeds: Optional[torch.tensor] = None
 
+    # For cross-encoder model
+    token_type_ids: Optional[torch.Tensor] = None
+
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
```