sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +394 -76
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +501 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/rotary_embedding.py +0 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +59 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +4 -1
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/mm_utils.py
CHANGED
```diff
@@ -1,5 +1,5 @@
 """
-
+Multi-modality utils
 """
 
 from abc import abstractmethod
@@ -9,11 +9,13 @@ import torch
 from torch import nn
 
 from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
     logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import print_warning_once
 from sglang.utils import logger
 
 
@@ -26,7 +28,7 @@ class MultiModalityDataPaddingPattern:
 
     @abstractmethod
     def pad_input_tokens(
-        self, input_ids: List[int],
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
         Pad the input ids sequence containing data tokens, and replace them with pad_values
@@ -49,13 +51,13 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         """
         This function will replace the data-tokens inbetween with pad_values accordingly
         """
-        pad_values = mm_inputs.
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]
         data_token_pairs = self.data_token_id_pairs
-        mm_inputs.
+        mm_inputs.data_offsets = []
         if data_token_pairs is None:
             data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
         if data_token_pairs is None:
-
+            print_warning_once(
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
@@ -77,10 +79,10 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
 
             if input_ids[start_idx] in start_token_ids:
                 data_idx += 1
-                mm_inputs.
+                mm_inputs.data_offsets += [start_idx]
 
-            if data_idx >= len(
-                data_idx = len(
+            if data_idx >= len(pad_values):
+                data_idx = len(pad_values) - 1
 
             num_tokens = end_idx - start_idx - 1
             pad_value = pad_values[data_idx]
@@ -94,68 +96,19 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         return padded_ids
 
 
-class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
-    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
-    which needs first to be expanded to multiple tokens, then replaced with their padding values
-
-    This strategy should be used when a single data token represents content that should
-    be expanded to multiple tokens during processing.
-    """
-
-    def __init__(
-        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
-    ) -> None:
-        self.num_data_token_calc_func = num_data_token_calc_func
-
-    def pad_input_tokens(
-        self, input_ids: List[int], mm_inputs: MultimodalInputs
-    ) -> List[int]:
-        """
-        This function will follow the procedure of:
-        1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
-        2. the padded data tokens will be replaced with their pad_values
-        """
-        image_grid_thws = mm_inputs.image_grid_thws
-        pad_values = mm_inputs.pad_values
-
-        image_indices = [
-            idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
-        ]
-
-        mm_inputs.image_offsets = []
-
-        input_ids_with_image = []
-        for image_cnt, _ in enumerate(image_grid_thws):
-            # print(f"image_cnt {image_cnt}")
-            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
-            if image_cnt == 0:
-                non_image_tokens = input_ids[: image_indices[image_cnt]]
-            else:
-                non_image_tokens = input_ids[
-                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
-                ]
-            input_ids_with_image.extend(non_image_tokens)
-            mm_inputs.image_offsets.append(len(input_ids_with_image))
-            pad_ids = pad_values * (
-                (num_image_tokens + len(pad_values)) // len(pad_values)
-            )
-            input_ids_with_image.extend(pad_ids[:num_image_tokens])
-        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
-
-        return input_ids_with_image
-
-
 class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
-    """In this pattern, data tokens should be represented as
+    """In this pattern, data tokens should be represented as repetitions of a single token
+    e.g. <image><image>....<image>, or <audio><audio>...<audio>
+    """
 
     def __init__(self, image_token_id: torch.Tensor) -> None:
         self.image_token_id = image_token_id
 
-    def pad_input_tokens(self, input_ids: List[int],
+    def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
         """
         This function will replace the data-tokens in between with pad_values accordingly
         """
-        pad_values =
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]
         assert len(pad_values) != 0
 
         input_ids_tensor = torch.tensor(input_ids)
@@ -170,138 +123,227 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
         return input_ids_tensor.tolist()
 
 
+def get_embedding_and_mask(
+    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
+    embedding_items: List[MultimodalDataItem],
+    placeholder_tensor: torch.Tensor,
+    input_ids: torch.Tensor,
+):
+    """
+    Get the multimodal embedding and its mask from input_ids
+
+    """
+    # 1. Get the embedding
+    embedding = data_embedding_func(embedding_items)
+
+    # 2. Check the embedding
+    if embedding.dim() == 2:
+        num_mm_tokens_in_embedding = embedding.shape[0]
+    else:
+        num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
+
+    # the mask of multimodal tokens from input_ids
+    special_multimodal_mask = torch.isin(
+        input_ids,
+        placeholder_tensor,
+    ).unsqueeze(-1)
+
+    num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
+    if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
+        logger.warning(
+            f"Number of tokens in multimodal embedding does not match those in the input text."
+            f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
+            "tokens from multimodal embeddings."
+        )
+        if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
+            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
+            # a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
+            # extend_start_loc and extend_seq_lens
+            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
+            if chunked_prefill_size != -1:
+                logger.warning(
+                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
+                )
+            # extract from the end: this is a compromise
+            if embedding.dim() == 2:
+                embedding = embedding[-num_mm_tokens_in_input_ids:, :]
+            else:
+                num_multimodal = num_mm_tokens_in_input_ids // embedding.shape[0]
+                embedding = embedding[-num_multimodal:, :]
+        else:
+            raise RuntimeError(
+                "Insufficient multimodal embedding length. This is an internal error"
+            )
+
+    return embedding, special_multimodal_mask
+
+
 def embed_mm_inputs(
-
+    mm_inputs: MultimodalInputs,
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
-
+    image_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
+    audio_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
     placeholder_token_ids: List[int] = None,
 ) -> Optional[torch.Tensor]:
     """
-    Calculate the
-    the help of a boolean mask denoting the embed locations
+    Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
 
-
-
+    Args:
+        placeholder_token_ids: denoting the token of multimodal data in input_ids.
+            If none, the pad_values of multimodal items are used
+
+    Returns:
+        final embedding: Optional[torch.Tensor]
     """
-
+
+    if mm_inputs is None:
         return None
 
-
+    # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
+    # we assume that multimodal data are represented with its pad_values in input_ids
+    placeholder_token_ids = placeholder_token_ids or [
+        item.pad_value for item in mm_inputs.mm_items
+    ]
 
-
-    special_image_mask = torch.isin(
-        input_ids,
-        torch.tensor(placeholder_token_ids, device=input_ids.device),
-    ).unsqueeze(-1)
+    placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
 
-
-    # print(f"{num_image_tokens_in_input_ids}")
-    # print(f"{input_ids}")
+    placeholder_masks = torch.isin(input_ids, placeholder_tensor)
 
-
-
-
+    appearing_pad_values = torch.unique(
+        input_ids[placeholder_masks], return_counts=False
+    )
+
+    if appearing_pad_values.numel() == 0:
+        # all been prefixed
         inputs_embeds = input_embedding(input_ids)
     else:
-
-
-
-
+        appearing_items = [
+            item
+            for item in mm_inputs.mm_items
+            if item.pad_value is not None and item.pad_value in appearing_pad_values
+        ]
 
-
-
-
-
-
+        using_all_items = False
+        if len(appearing_items) == 0:
+            # This happens mostly when arg placeholder_token_ids is passed
+            logger.warning_once(
+                "No multimodal data item's pad value exist in placeholder ids. Using all items"
             )
-
-
-
-
-
-
-
+            using_all_items = True
+            appearing_items = mm_inputs.mm_items
+
+        embeddings, masks = [], []
+
+        # 2. Get multimodal embedding separately
+        # TODO: make this more generic
+        # Try get image embedding if any
+        if (
+            any(True for item in appearing_items if item.is_image())
+            and image_data_embedding_func
+        ):
+            items = [item for item in appearing_items if item.is_image()]
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=image_data_embedding_func,
+                embedding_items=items,
+                placeholder_tensor=(
+                    placeholder_tensor
+                    if using_all_items
+                    else torch.tensor(
+                        [item.pad_value for item in items],
+                        device=input_ids.device,
+                    )
+                ),
+                input_ids=input_ids,
             )
+            embeddings += [embedding]
+            masks += [mask]
 
-
-
-
-
-
-
-
-
+        # Try get audio embedding if any
+        if (
+            any(True for item in appearing_items if item.is_audio())
+            and audio_data_embedding_func
+        ):
+            items = [item for item in appearing_items if item.is_audio()]
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=audio_data_embedding_func,
+                embedding_items=items,
+                placeholder_tensor=(
+                    placeholder_tensor
+                    if using_all_items
+                    else torch.tensor(
+                        [item.pad_value for item in items],
+                        device=input_ids.device,
                     )
+                ),
+                input_ids=input_ids,
+            )
+            embeddings += [embedding]
+            masks += [mask]
 
+        # 3. Get input embeddings
         vocab_size = input_embedding.num_embeddings
-        # Important: clamp after getting original
-        # Clamp input ids. This is because the input_ids for the
-        # filled with the hash values of the
+        # Important: clamp after getting original multimodal regions
+        # Clamp input ids. This is because the input_ids for the multimodal tokens are
+        # filled with the hash values of the multimodal for the prefix matching in the radix attention.
        # There values are useless because their embeddings will be replaced by vision embeddings anyway.
         input_ids.clamp_(min=0, max=vocab_size - 1)
         inputs_embeds = input_embedding(input_ids)
 
-
-
-
-
-
-
-
-        )
-    return inputs_embeds
-
-
-def embed_image_embedding(
-    inputs_embeds: torch.Tensor,
-    image_embedding: torch.Tensor,
-    image_bounds: torch.Tensor,
-) -> torch.Tensor:
-    """
-    scatter image_embedding into inputs_embeds according to image_bounds
-    """
-    if len(image_bounds) > 0:
-        image_indices = torch.stack(
-            [
-                torch.arange(start, end, dtype=torch.long)
-                for start, end in image_bounds.tolist()
-            ]
-        ).to(inputs_embeds.device)
-
-        inputs_embeds.scatter_(
-            0,
-            image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
-            image_embedding.view(-1, image_embedding.shape[-1]),
-        )
+        # 4. scatter embeddings into input embedding
+        for embedding, mask in zip(embeddings, masks):
+            mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            inputs_embeds = inputs_embeds.masked_scatter(
+                mask,
+                embedding.to(inputs_embeds.device, inputs_embeds.dtype),
+            )
     return inputs_embeds
 
 
 def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
-
-
+    language_model: nn.Module,
+    image_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
+    audio_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
     placeholder_token_ids: List[int] = None,
-
+    **kwargs,
+) -> torch.Tensor:
     """
-
-    with a language model as causal model
+    A general wrapper function to get final input embeds from multimodal models with a language model as causal model
 
     Args:
         placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
+        image_data_embedding_func : the function returning the image embedding
+        audio_data_embedding_func : the function returning the image embedding
+
+    Returns:
+        inputs_embedding
+        forwarded hidden states
 
     """
+
+    assert hasattr(language_model, "get_input_embeddings")
+    embed_tokens = language_model.get_input_embeddings()
     if (
         not forward_batch.forward_mode.is_decode()
         and forward_batch.contains_mm_inputs()
     ):
-
+        mm_input = forward_batch.merge_mm_inputs()
         inputs_embeds = embed_mm_inputs(
-            mm_input
+            mm_inputs=mm_input,
             input_ids=input_ids,
             input_embedding=embed_tokens,
-
+            image_data_embedding_func=image_data_embedding_func,
+            audio_data_embedding_func=audio_data_embedding_func,
             placeholder_token_ids=placeholder_token_ids,
         )
         # once used, mm_inputs is useless
```
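The hunk above replaces the old index-based `embed_image_embedding` helper (which used `scatter_` over explicit `image_bounds`) with a mask-based flow: build a boolean mask of placeholder positions with `torch.isin`, then write the image/audio embeddings into those positions with `masked_scatter`. The snippet below is a minimal, self-contained sketch of that mechanic only; the token ids, hidden size, and random "tower output" are made up for illustration and are not sglang code.

```python
import torch

# Toy setup: vocab of 10 ids, hidden size 4. Ids 7 and 9 play the role of the
# pad_values that mark multimodal positions in input_ids (see embed_mm_inputs).
input_ids = torch.tensor([1, 7, 7, 2, 9, 3])
placeholder_tensor = torch.tensor([7, 9])
embed = torch.nn.Embedding(10, 4)

# Boolean mask of multimodal positions, as in get_embedding_and_mask.
mask = torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)  # [seq_len, 1]

# Stand-in for the vision/audio tower output: one row per masked position.
mm_embedding = torch.randn(int(mask.sum()), 4)

# Scatter the multimodal rows over the text embeddings at the masked positions,
# mirroring step 4 of embed_mm_inputs.
inputs_embeds = embed(input_ids)  # [seq_len, 4]
inputs_embeds = inputs_embeds.masked_scatter(
    mask.expand_as(inputs_embeds), mm_embedding
)
```

The remaining hunks of mm_utils.py continue below.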
```diff
@@ -310,7 +352,13 @@ def general_mm_embed_routine(
     else:
         inputs_embeds = embed_tokens(input_ids)
 
-
+    hidden_states = language_model(
+        input_ids=None,
+        forward_batch=forward_batch,
+        input_embeds=inputs_embeds,
+        **kwargs,
+    )
+    return hidden_states
 
 
 def get_multimodal_data_bounds(
@@ -322,15 +370,13 @@ def get_multimodal_data_bounds(
     Returns:
         [bounds_count, 2]
     """
-    # All the
-    # bound token ids.
+    # All the multimodal data in the batch should share the same special bound token ids.
     start_tokens = [s for s, _e in token_pairs]
     end_tokens = [e for _s, e in token_pairs]
 
     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)
 
-    # print(input_ids)
     start_cond = torch.isin(
         input_ids, torch.tensor(start_tokens, device=input_ids.device)
     )
@@ -339,7 +385,7 @@ def get_multimodal_data_bounds(
     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)
 
-    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the
+    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
     if len(data_start_tokens) != len(data_end_tokens):
         if (
             len(data_start_tokens) + 1 == len(data_end_tokens)
@@ -352,14 +398,14 @@ def get_multimodal_data_bounds(
                 data_start_tokens,
             ]
         )
-
+    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
 
-    if
+    if valid_mm_data_nums == 0:
        return torch.zeros((0, 2), device=input_ids.device)
 
     # Filter out pairs where start_token >= end_token
     valid_pairs = []
-    for i in range(
+    for i in range(valid_mm_data_nums):
         start_token = data_start_tokens[i]
         end_token = data_end_tokens[i]
         if start_token < end_token:
```