sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/mm_utils.py
CHANGED
@@ -2,6 +2,7 @@
 Multi-modality utils
 """
 
+import dataclasses
 import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
@@ -15,10 +16,15 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalInputs,
     global_server_args_dict,
 )
+from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import flatten_nested_list, print_warning_once
+from sglang.utils import logger
 
-logger = logging.getLogger(__name__)
+# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
+# to ensure consistent logging behavior across the codebase. This prevents issues with log
+# propagation that can cause some log messages (like 'server is fired up') to not appear
+# in the console when multimodal support is enabled.
 
 
 class MultiModalityDataPaddingPattern:
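
The logging note above refers to standard Python logging propagation. The snippet below is a generic illustration of the failure mode it describes (the propagate flag is set by hand purely for demonstration; this is not how sglang configures logging): a module-specific logger whose records never reach a configured handler drops INFO messages silently, while reusing an already-configured shared logger does not.

import logging

shared = logging.getLogger("sglang")                 # assume this one is configured elsewhere
shared.addHandler(logging.StreamHandler())
shared.setLevel(logging.INFO)

module_logger = logging.getLogger("sglang.srt.managers.mm_utils")
module_logger.propagate = False                      # simulate broken propagation
module_logger.info("server is fired up")             # dropped: no handler, no propagation
shared.info("server is fired up")                    # printed by the shared handler
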
@@ -41,17 +47,32 @@ class MultiModalityDataPaddingPattern:
 class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
 
+    The padded value in a region enclosed by a token pair with be the same one, as the MultimodalDataItem's pad value
+
     This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
     """
 
-    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
+    def __init__(
+        self,
+        data_token_pairs: Optional[List[Tuple[int, int]]],
+        data_start_token_ids: Optional[List[int]] = None,
+    ) -> None:
+        """
+
+        Args:
+            data_start_token_ids marks the start of a single multimodal data
+            See Minicpmo's slice_start_id for example
+        """
         self.data_token_id_pairs = data_token_pairs
+        self.data_start_token_ids = data_start_token_ids or [
+            s for s, _e in data_token_pairs
+        ]
 
     def pad_input_tokens(
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-        This function will replace the data-tokens
+        This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
         data_token_pairs = self.data_token_id_pairs
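
An aside on the pattern described above: every token that falls between a start/end pair is replaced by the owning item's pad value, so later stages can locate the multimodal region just by scanning for that value. A minimal self-contained sketch of the idea (the token IDs, pad value, and helper name are invented for illustration; this is not the sglang implementation, which also tracks per-item offsets):

IMG_START, IMG_END = 32000, 32001       # hypothetical <image> / </image> token ids

def pad_between_pairs(input_ids, start_id, end_id, pad_value):
    # Replace every token strictly between a start/end pair with pad_value.
    out, inside = [], False
    for tok in input_ids:
        if tok == start_id:
            inside = True
            out.append(tok)
        elif tok == end_id:
            inside = False
            out.append(tok)
        else:
            out.append(pad_value if inside else tok)
    return out

print(pad_between_pairs([1, IMG_START, 5, 6, IMG_END, 2], IMG_START, IMG_END, 7777))
# [1, 32000, 7777, 7777, 32001, 2]
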
@@ -79,7 +100,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         for start_idx, end_idx in zip(start_indices, end_indices):
             padded_ids.extend(input_ids[last_idx : start_idx + 1])
 
-            if input_ids[start_idx] in
+            if input_ids[start_idx] in self.data_start_token_ids:
                 data_idx += 1
                 mm_inputs.data_offsets += [start_idx]
 
@@ -170,30 +191,140 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
                 output_ids_tensor[start_idx:end_idx] = pad_value
             else:
                 logger.warning(f"Skipping region {i} due to None pad_value.")
-
         return output_ids_tensor.tolist()
 
 
+embedding_cache = None
+
+
+def init_embedding_cache(max_size: int):
+    global embedding_cache
+    embedding_cache = MultiModalCache(max_size)
+
+
+def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int:
+    hash_list = [item.hash for item in embedding_items]
+    return hash(tuple(hash_list))
+
+
+def get_embedding_chunk(
+    embedding: torch.Tensor,
+    extend_prefix_len: int,
+    extend_seq_len: int,
+    items_offset: List[Tuple[int, int]],
+) -> Tuple[torch.Tensor, int, int]:
+    """
+    Extract a chunk of embeddings based on the specified prefix length, sequence length, and offset ranges.
+
+    Args:
+        embedding: The full embedding tensor to extract a chunk from
+        extend_prefix_len: The starting position (prefix length) for extraction
+        extend_seq_len: The number of tokens to extract
+        items_offset: List of [start, end] offset ranges for multimodal items in the input sequence
+
+    Returns:
+        A tuple containing:
+        - The extracted embedding chunk as a tensor
+        - The start index used for extraction
+        - The end index used for extraction
+
+    Note:
+        If there's no overlap between the requested range and the offset ranges,
+        an empty tensor is returned with zeros for start and end indices.
+    """
+    start_index, end_index = 0, 0
+    extend_start_index = extend_prefix_len
+    extend_end_index = extend_prefix_len + extend_seq_len - 1
+
+    for start, end in items_offset:
+        if extend_start_index >= start and extend_start_index <= end:
+            start_index += extend_start_index - start
+        elif extend_start_index > end:
+            start_index += end - start + 1
+
+        if extend_end_index >= start and extend_end_index <= end:
+            end_index += extend_end_index - start + 1
+        elif extend_end_index > end:
+            end_index += end - start + 1
+    # some models embedding is 3-dim, reshape it to 2-dim
+    embedding = embedding.reshape(-1, embedding.shape[-1])
+    embedding_chunk = embedding[start_index:end_index]
+    return embedding_chunk, start_index, end_index
+
+
 def get_embedding_and_mask(
     data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
     embedding_items: List[MultimodalDataItem],
     placeholder_tensor: torch.Tensor,
     input_ids: torch.Tensor,
-
+    items_size: List[int],
+    prefix_length: List[int],
+    extend_length: List[int],
+    items_offset_list: List[List[Tuple[int, int]]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
-
+    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
+
+    Args:
+        data_embedding_func: Function that generates embeddings for multimodal items
+        embedding_items: List of multimodal items to embed
+        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
+        input_ids: The input token IDs tensor
+        items_size: Cumulative sizes of multimodal items per request
+        prefix_length: Prefix lengths for each request
+        extend_length: Sequence lengths for each request
+        items_offset_list: List of offset ranges for multimodal items in each request
 
+    Returns:
+        A tuple containing:
+        - The generated embeddings tensor
+        - A boolean mask tensor indicating where these embeddings should be placed
+
+    Raises:
+        AssertionError: If the number of multimodal tokens in input_ids doesn't match
+            the number of tokens in the generated embeddings
     """
     # 1. Get the embedding
-    embedding = data_embedding_func(embedding_items)
+    # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
+    embedding_list = []
+    for i in range(len(items_size) - 1):
+        if items_size[i] == items_size[i + 1]:
+            continue
+        embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
+        items_offset = items_offset_list[i]
+        embedding_items_hash = get_embedding_hash(embedding_items_per_req)
+        # if all items has been prefixed, we do not need to calculate embedding
+        if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
+            continue
+        embedding_per_req = embedding_cache.get(embedding_items_hash)
+        if embedding_per_req is None:
+            embedding_per_req = data_embedding_func(embedding_items_per_req)
+            if not embedding_cache.put(embedding_items_hash, embedding_per_req):
+                print_warning_once(
+                    "Multimodal embedding cache is full. Consider increasing the "
+                    "`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
+                )
 
+        embedding_per_req_chunk, _, end_index = get_embedding_chunk(
+            embedding=embedding_per_req,
+            extend_prefix_len=prefix_length[i],
+            extend_seq_len=extend_length[i],
+            items_offset=items_offset,
+        )
+        # remove this item from cache if chunk reaches to the end
+        embedding_per_req_length = (
+            embedding_per_req.shape[0]
+            if embedding_per_req.dim() == 2
+            else embedding_per_req.shape[0] * embedding_per_req.shape[1]
+        )
+        if end_index == embedding_per_req_length:
+            embedding_cache.free(embedding_items_hash)
+        embedding_list.append(embedding_per_req_chunk)
+    if len(embedding_list) == 0:
+        return None, None
+    embedding = torch.concat(embedding_list, dim=0)
     # 2. Check the embedding
-    if embedding.dim() == 2:
-        num_mm_tokens_in_embedding = embedding.shape[0]
-    else:
-        num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
-
-    # the mask of multimodal tokens from input_ids
+    num_mm_tokens_in_embedding = embedding.shape[0]
     special_multimodal_mask = torch.isin(
         input_ids,
         placeholder_tensor,
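
To make the index arithmetic in get_embedding_chunk concrete, here is a small self-contained worked example with made-up numbers (it is not sglang code; the helper simply mirrors the start/end bookkeeping shown above for a 2-D embedding). One request has a single image whose six placeholder tokens occupy input positions 4..9, and prefill runs in two chunks of five tokens:

def chunk_indices(extend_prefix_len, extend_seq_len, items_offset):
    # Mirrors the start/end bookkeeping of get_embedding_chunk for a 2-D embedding.
    start_index, end_index = 0, 0
    extend_start = extend_prefix_len
    extend_end = extend_prefix_len + extend_seq_len - 1
    for start, end in items_offset:
        if start <= extend_start <= end:
            start_index += extend_start - start
        elif extend_start > end:
            start_index += end - start + 1
        if start <= extend_end <= end:
            end_index += extend_end - start + 1
        elif extend_end > end:
            end_index += end - start + 1
    return start_index, end_index

items_offset = [(4, 9)]                     # image placeholders sit at positions 4..9
print(chunk_indices(0, 5, items_offset))    # first chunk covers positions 0..4 -> (0, 1): one embedding row
print(chunk_indices(5, 5, items_offset))    # second chunk covers positions 5..9 -> (1, 6): rows 1..5
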
@@ -202,14 +333,11 @@ def get_embedding_and_mask(
     num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
-            f"Number of tokens in multimodal embedding does not match those in the input text."
+            f"Number of tokens in multimodal embedding does not match those in the input text. "
             f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
             "tokens from multimodal embeddings."
         )
         if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
-            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
-            # a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
-            # extend_start_loc and extend_seq_lens
             chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
             if chunked_prefill_size != -1:
                 logger.warning(
@@ -230,7 +358,9 @@ def get_embedding_and_mask(
 
 
 def embed_mm_inputs(
-
+    mm_inputs_list: List[MultimodalInputs],
+    extend_prefix_lens: List[int],
+    extend_seq_lens: List[int],
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
     image_data_embedding_func: Callable[
@@ -242,125 +372,133 @@ def embed_mm_inputs(
     placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
     """
-
+    Embed multimodal inputs and integrate them with text token embeddings.
+
+    Args:
+        mm_inputs_list: List of multimodal inputs to process
+        extend_prefix_lens: Prefix lengths for each request
+        extend_seq_lens: Sequence lengths for each request
+        input_ids: Input token IDs tensor
+        input_embedding: Embedding layer for text tokens
+        image_data_embedding_func: Function to embed image data
+        audio_data_embedding_func: Function to embed audio data
+        placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)
 
-
-
-        If none, the pad_values of multimodal items are used
-
-    Returns:
-        final embedding: Optional[torch.Tensor]
+    Returns:
+        Combined embedding tensor with multimodal content integrated
     """
 
-    if
+    if mm_inputs_list is None:
         return None
 
     # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
     # we assume that multimodal data are represented with its pad_values in input_ids
-
-
-
-    if placeholder_tokens is not None:
-        placeholder_token_ids = flatten_nested_list(
-            [placeholder_token for placeholder_token in placeholder_tokens.values()]
-        )
-    else:
-        placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
-
-    assert isinstance(placeholder_token_ids[0], int)
-
-    placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
-
-    placeholder_masks = torch.isin(input_ids, placeholder_tensor)
+    item_flatten_list = []
+    for mm_inputs in mm_inputs_list:
+        item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
 
-
-        input_ids[placeholder_masks], return_counts=False
-    )
+    embeddings, masks = [], []
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # 2. Get multimodal embedding separately
+    # TODO: make this more generic
+    # Try get image embedding if any
+    if (
+        any(True for item in item_flatten_list if item.is_image())
+        and image_data_embedding_func
+    ):
+        items = [item for item in item_flatten_list if item.is_image()]
+        placeholder_tensor = torch.tensor(
+            [item.pad_value for item in items],
+            device=input_ids.device,
+        )
+        # calculate per request items length offset
+        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+        items_offsets = []
+        for i, mm_inputs in enumerate(mm_inputs_list):
+            image_items = [item for item in mm_inputs.mm_items if item.is_image()]
+            items_size[i + 1] = len(image_items)
+            items_offsets.append(
+                flatten_nested_list(
+                    [
+                        item.image_offsets
+                        for item in mm_inputs.mm_items
+                        if item.is_image()
+                    ]
+                )
             )
-
-        appearing_items = mm_inputs.mm_items
+        items_size = torch.cumsum(items_size, dim=0).tolist()
 
-
+        embedding, mask = get_embedding_and_mask(
+            data_embedding_func=image_data_embedding_func,
+            embedding_items=items,
+            placeholder_tensor=placeholder_tensor,
+            input_ids=input_ids,
+            items_size=items_size,
+            prefix_length=extend_prefix_lens,
+            extend_length=extend_seq_lens,
+            items_offset_list=items_offsets,
+        )
+        embeddings += [embedding]
+        masks += [mask]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # Try get audio embedding if any
+    if (
+        any(True for item in item_flatten_list if item.is_audio())
+        and audio_data_embedding_func
+    ):
+        items = [item for item in item_flatten_list if item.is_audio()]
+        placeholder_tensor = torch.tensor(
+            [item.pad_value for item in items],
+            device=input_ids.device,
+        )
+        items_offsets = []
+        # calculate per request items length offset
+        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+        for i, mm_inputs in enumerate(mm_inputs_list):
+            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
+            items_size[i + 1] = len(audio_items)
+            items_offsets.append(
+                flatten_nested_list(
+                    [
+                        item.audio_offsets
+                        for item in mm_inputs.mm_items
+                        if item.is_audio()
+                    ]
+                )
             )
-
-        masks += [mask]
+        items_size = torch.cumsum(items_size, dim=0)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    # 4. Scatter embeddings into input embedding
-    for embedding, mask in zip(embeddings, masks):
-        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-        inputs_embeds = inputs_embeds.masked_scatter(
-            mask,
-            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
+        embedding, mask = get_embedding_and_mask(
+            data_embedding_func=audio_data_embedding_func,
+            embedding_items=items,
+            placeholder_tensor=placeholder_tensor,
+            input_ids=input_ids,
+            items_size=items_size,
+            prefix_length=extend_prefix_lens,
+            extend_length=extend_seq_lens,
+            items_offset_list=items_offsets,
+        )
+        embeddings += [embedding]
+        masks += [mask]
+
+    # 3. Get input embeddings
+    vocab_size = input_embedding.num_embeddings
+    # Important: clamp after getting original multimodal regions
+    # Clamp input ids. This is because the input_ids for the multimodal tokens are
+    # filled with the hash values of the multimodal for the prefix matching in the radix attention.
+    # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+    input_ids.clamp_(min=0, max=vocab_size - 1)
+    inputs_embeds = input_embedding(input_ids)
+
+    # 4. scatter embeddings into input embedding
+    for embedding, mask in zip(embeddings, masks):
+        if embedding is None or mask is None:
+            continue
+        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+        inputs_embeds = inputs_embeds.masked_scatter(
+            mask,
+            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
+        )
     return inputs_embeds
 
 
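
The scatter step at the end of embed_mm_inputs can be seen in isolation with a tiny runnable snippet (made-up sizes and token IDs, not sglang code): rows of inputs_embeds whose token id is one of the placeholder pad values are overwritten, in order, by the rows of the multimodal embedding.

import torch

input_ids = torch.tensor([11, 900, 900, 12])       # 900 is a hypothetical pad_value
placeholder_tensor = torch.tensor([900])
inputs_embeds = torch.zeros(4, 4)                   # stand-in for the text embeddings
mm_embedding = torch.ones(2, 4)                     # one row per placeholder token

mask = torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)   # shape (4, 1)
inputs_embeds = inputs_embeds.masked_scatter(
    mask.expand_as(inputs_embeds),
    mm_embedding.to(inputs_embeds.dtype),
)
print(inputs_embeds[1])   # tensor([1., 1., 1., 1.]) -- the placeholder rows were replaced
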
@@ -368,37 +506,53 @@ def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
     language_model: nn.Module,
-    image_data_embedding_func:
-        [List[MultimodalDataItem]], torch.Tensor
+    image_data_embedding_func: Optional[
+        Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
-    audio_data_embedding_func:
-        [List[MultimodalDataItem]], torch.Tensor
+    audio_data_embedding_func: Optional[
+        Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
-    placeholder_tokens: dict[Modality, List[int]] = None,
+    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
     **kwargs,
 ) -> torch.Tensor:
     """
-
+    Process multimodal inputs and forward through language model.
 
-
-
-
-
-
-
-
+    Args:
+        input_ids: Input token IDs tensor
+        forward_batch: Batch information for model forward pass
+        language_model: Base language model to use
+        image_data_embedding_func: Function to embed image data
+        audio_data_embedding_func: Function to embed audio data
+        placeholder_tokens: Token IDs for multimodal placeholders
+        **kwargs: Additional arguments passed to language model
 
+    Returns:
+        Hidden states from language model forward pass
     """
-
     assert hasattr(language_model, "get_input_embeddings")
     embed_tokens = language_model.get_input_embeddings()
     if (
         not forward_batch.forward_mode.is_decode()
         and forward_batch.contains_mm_inputs()
     ):
-
+        mm_inputs_list = [
+            mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None
+        ]
+        extend_prefix_lens = [
+            prefix_len
+            for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu)
+            if forward_batch.mm_inputs[i] is not None
+        ]
+        extend_seq_lens = [
+            seq_len
+            for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
+            if forward_batch.mm_inputs[i] is not None
+        ]
         inputs_embeds = embed_mm_inputs(
-
+            mm_inputs_list=mm_inputs_list,
+            extend_prefix_lens=extend_prefix_lens,
+            extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
             input_embedding=embed_tokens,
             image_data_embedding_func=image_data_embedding_func,
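
Closing note on the caching added in this release: a partially prefilled request can now reuse its multimodal embedding across chunks and drop it once the final chunk is consumed (see get_embedding_and_mask above). The real MultiModalCache lives in sglang/srt/mem_cache/multimodal_cache.py, which this page does not show, so the sketch below is only a toy stand-in that mimics the get/put/free interface used in this file; the names and byte accounting are assumptions for illustration.

import torch

class ToyMultiModalCache:
    # Toy stand-in for sglang's MultiModalCache; only the interface is mirrored.
    def __init__(self, max_size_bytes: int):
        self.max_size, self.used, self.store = max_size_bytes, 0, {}

    def put(self, key: int, value: torch.Tensor) -> bool:
        size = value.numel() * value.element_size()
        if self.used + size > self.max_size:
            return False        # caller then prints the SGLANG_VLM_CACHE_SIZE_MB warning
        self.store[key] = value
        self.used += size
        return True

    def get(self, key: int):
        return self.store.get(key)

    def free(self, key: int) -> None:
        value = self.store.pop(key, None)
        if value is not None:
            self.used -= value.numel() * value.element_size()

cache = ToyMultiModalCache(max_size_bytes=1 << 20)
key = hash((123456,))                       # mirrors get_embedding_hash over item hashes
cache.put(key, torch.zeros(6, 64))          # first chunk computes and caches the embedding
assert cache.get(key) is not None           # later chunks hit the cache
cache.free(key)                             # freed once the last chunk is consumed
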