sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/lang/chat_template.py +24 -0
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +5 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/conversation.py +29 -4
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +678 -83
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +5 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +503 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +60 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +16 -5
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/mllama4.py +154 -0
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -75,6 +75,7 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     init_custom_process_group,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
@@ -123,6 +124,11 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.use_mla_backend = (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        )
+        self.attention_chunk_size = model_config.attention_chunk_size

         # Model-specific adjustment
         self.model_specific_adjustment()
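The `use_mla_backend` flag introduced here caches the MLA check that 0.4.4.post3 repeated inline wherever it needed to know the attention architecture (see the later hunks in this file). A minimal sketch of the predicate, with a stand-in `AttentionArch` enum since only the two referenced fields appear in the diff:

    from enum import Enum, auto

    class AttentionArch(Enum):
        MHA = auto()
        MLA = auto()

    def compute_use_mla_backend(attention_arch: AttentionArch, disable_mla: bool) -> bool:
        # Same predicate the old code repeated at each call site; 0.4.5 computes it
        # once in ModelRunner.__init__ and reuses self.use_mla_backend below.
        return attention_arch == AttentionArch.MLA and not disable_mla

    print(compute_use_mla_backend(AttentionArch.MLA, disable_mla=False))  # True
    print(compute_use_mla_backend(AttentionArch.MHA, disable_mla=False))  # False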
@@ -147,15 +153,18 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
+                "deepep_mode": server_args.deepep_mode,
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+                "n_share_experts_fusion": server_args.n_share_experts_fusion,
+                "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
+                "use_mla_backend": self.use_mla_backend,
             }
         )

@@ -216,27 +225,38 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args

-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not server_args.disable_mla
-        ):
+        if server_args.enable_flashinfer_mla:
+            # TODO: remove this branch after enable_flashinfer_mla is deprecated
+            logger.info("MLA optimization is turned on. Use flashinfer backend.")
+            server_args.attention_backend = "flashinfer"
+        elif server_args.enable_flashmla:
+            # TODO: remove this branch after enable_flashmla is deprecated
+            logger.info("MLA optimization is turned on. Use flashmla decode.")
+            server_args.attention_backend = "flashmla"
+        elif server_args.attention_backend is None:
+            # By default, use flashinfer for non-mla attention and triton for mla attention
+            if not self.use_mla_backend:
+                server_args.attention_backend = (
+                    "flashinfer" if is_flashinfer_available() else "triton"
+                )
+            else:
+                server_args.attention_backend = "triton"
+            logger.info(
+                f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+            )
+        elif self.use_mla_backend:
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
-                    logger.info(
-                        "MLA optimization is turned on. Use flashinfer mla backend."
-                    )
-                    server_args.attention_backend = "flashinfer_mla"
-                elif server_args.enable_flashmla:
-                    logger.info("MLA optimization is turned on. Use flashmla decode.")
-                    server_args.attention_backend = "flashmla"
-                elif server_args.attention_backend == "fa3":
+                if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
                     logger.info(
-                        f"MLA optimization is turned on. Use flash attention 3 backend."
+                        f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
                 else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    server_args.attention_backend = "triton"
+                    raise ValueError(
+                        f"Invalid attention backend for MLA: {server_args.attention_backend}"
+                    )
+            else:
+                raise ValueError(f"MLA optimization not supported on CPU.")

         if server_args.enable_double_sparsity:
             logger.info(
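In 0.4.5 the backend choice is made up front from the deprecated convenience flags, the explicit --attention-backend value, and the new use_mla_backend flag. A simplified, standalone sketch of that selection order; the argument names are illustrative, and the real method also mutates server_args and logs its choice:

    from typing import Optional

    def select_attention_backend(
        enable_flashinfer_mla: bool,
        enable_flashmla: bool,
        attention_backend: Optional[str],
        use_mla_backend: bool,
        flashinfer_available: bool,
    ) -> str:
        # Deprecated convenience flags still take priority, mirroring the TODO branches above.
        if enable_flashinfer_mla:
            return "flashinfer"
        if enable_flashmla:
            return "flashmla"
        if attention_backend is None:
            # Default: flashinfer for non-MLA models when installed, triton otherwise.
            if not use_mla_backend:
                return "flashinfer" if flashinfer_available else "triton"
            return "triton"
        if use_mla_backend and attention_backend not in ("flashinfer", "fa3", "triton"):
            raise ValueError(f"Invalid attention backend for MLA: {attention_backend}")
        return attention_backend

    # Example: an MLA model with no explicit backend now defaults to triton.
    print(select_attention_backend(False, False, None, True, True))  # triton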
@@ -251,17 +271,16 @@ class ModelRunner:
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

         if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
+            self.mem_fraction_static *= 0.90
             logger.info(
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
             )

-        if self.model_config.hf_config.architectures == [
-            "MllamaForConditionalGeneration"
-        ]:
-            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-            server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
+            server_args.chunked_prefill_size = -1

         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
@@ -269,22 +288,11 @@ class ModelRunner:
             "Qwen2_5_VLForConditionalGeneration"
         ]:
             # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
-            )
-            server_args.chunked_prefill_size = -1
-            server_args.disable_radix_cache = True
-
-        if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
-            # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
-            )
-            server_args.chunked_prefill_size = -1
+            logger.info("Automatically disable radix cache for qwen-vl series.")
             server_args.disable_radix_cache = True

         if server_args.enable_deepep_moe:
-            logger.info("DeepEP is turned on.")
+            logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")

     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -646,10 +654,7 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
+        if self.use_mla_backend:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
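For reference, a worked example of the per-token MLA KV-cache cell size computed in this hunk. The DeepSeek-V3-like dimensions and the trailing multiplication by the KV dtype's element size (not visible in the hunk) are assumptions for illustration only:

    # Hypothetical MLA model dimensions, not taken from the diff.
    kv_lora_rank = 512
    qk_rope_head_dim = 64
    num_hidden_layers = 61
    kv_dtype_bytes = 2  # assuming a bfloat16 KV cache

    cell_size = (kv_lora_rank + qk_rope_head_dim) * num_hidden_layers * kv_dtype_bytes
    print(cell_size)  # 70272 -> roughly 70 KB of KV cache per token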
@@ -760,10 +765,7 @@ class ModelRunner:
         # Draft worker shares req_to_token_pool with the target worker.
         assert self.is_draft_worker

-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
+        if self.use_mla_backend:
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
@@ -834,14 +836,21 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferAttnBackend,
-            )
+            if not self.use_mla_backend:
+                from sglang.srt.layers.attention.flashinfer_backend import (
+                    FlashInferAttnBackend,
+                )

-            # Init streams
-            if self.server_args.speculative_algorithm == "EAGLE":
-                self.plan_stream_for_flashinfer = torch.cuda.Stream()
-            self.attn_backend = FlashInferAttnBackend(self)
+                # Init streams
+                if self.server_args.speculative_algorithm == "EAGLE":
+                    self.plan_stream_for_flashinfer = torch.cuda.Stream()
+                self.attn_backend = FlashInferAttnBackend(self)
+            else:
+                from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                    FlashInferMLAAttnBackend,
+                )
+
+                self.attn_backend = FlashInferMLAAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
                 "Window attention is not supported in the triton attention backend. "
@@ -867,12 +876,6 @@ class ModelRunner:
             )

             self.attn_backend = TorchNativeAttnBackend(self)
-        elif self.server_args.attention_backend == "flashinfer_mla":
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAAttnBackend,
-            )
-
-            self.attn_backend = FlashInferMLAAttnBackend(self)
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

sglang/srt/model_loader/loader.py
CHANGED
@@ -489,6 +489,14 @@ class DummyModelLoader(BaseModelLoader):
             # NOTE(woosuk): For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
+
+            # Model weight loading consists of two stages:
+            # 1. Initial weight loading.
+            # 2. Post-processing of weights, including assigning specific member variables.
+            # For `dummy_init`, only the second stage is required.
+            if hasattr(model, "post_load_weights"):
+                model.post_load_weights()
+
         return model.eval()

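The hook added above lets dummy-initialized models run the same post-processing that normally follows real weight loading. A self-contained sketch of the two-stage pattern; `TinyModel` and `dummy_load` are illustrative stand-ins, and only the `post_load_weights` hook name comes from the diff:

    import torch
    from torch import nn

    class TinyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(8, 8)
            self.weight_scale = None

        def post_load_weights(self) -> None:
            # Stage 2: derive member variables that usually depend on loaded weights.
            self.weight_scale = self.proj.weight.abs().max()

    def dummy_load(model: nn.Module) -> nn.Module:
        # Stage 1: fill parameters with random values instead of reading a checkpoint.
        with torch.no_grad():
            for p in model.parameters():
                nn.init.normal_(p)
        # Stage 2: run post-processing if the model defines it, mirroring the diff.
        if hasattr(model, "post_load_weights"):
            model.post_load_weights()
        return model.eval()

    print(dummy_load(TinyModel()).weight_scale is not None)  # True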
sglang/srt/models/clip.py
CHANGED
@@ -17,7 +17,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list


 class CLIPVisionEmbeddings(nn.Module):
@@ -368,7 +368,6 @@ class CLIPVisionTransformer(nn.Module):
         self,
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         hidden_states = self.embeddings(pixel_values.to(self.device))
         hidden_states = self.pre_layrnorm(hidden_states)

@@ -456,12 +455,18 @@ class CLIPModel(nn.Module):
         get_embedding: bool = True,
     ):
         assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
-
+        mm_inputs = []
         if forward_batch.mm_inputs is not None:
-
-
-
-
+            mm_inputs = forward_batch.mm_inputs
+        pixel_values_list = [
+            item.pixel_values
+            for item in flatten_nested_list(
+                [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
+            )
+        ]
+        if len(pixel_values_list) != 0:
+            pixel_values = torch.concat(pixel_values_list)
+            vision_outputs = self.vision_model(pixel_values)
             pooled_output = vision_outputs[:, 0, :]
             image_embeds = self.visual_projection(pooled_output)
             image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
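The comprehension above gathers `pixel_values` from every `mm_item` of every non-None `mm_input` in the batch. A stand-in for the `flatten_nested_list` helper it relies on; the real utility lives in `sglang.srt.utils` and may differ in detail:

    from typing import Any, List

    def flatten_nested_list(nested: List[Any]) -> List[Any]:
        # Recursively flatten arbitrarily nested lists into a single flat list.
        flat: List[Any] = []
        for item in nested:
            if isinstance(item, list):
                flat.extend(flatten_nested_list(item))
            else:
                flat.append(item)
        return flat

    # Stand-ins for per-request mm_items lists; the CLIP code concatenates the
    # corresponding pixel_values tensors instead of strings.
    per_request_items = [["img_a", "img_b"], ["img_c"]]
    print(flatten_nested_list(per_request_items))  # ['img_a', 'img_b', 'img_c']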
sglang/srt/models/deepseek_janus_pro.py
CHANGED
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -1959,8 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)

-    def get_image_feature(self,
-        pixel_values =
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
         bs, n = pixel_values.shape[0:2]
         pixel_values = pixel_values.to(
             device=self.vision_model.device, dtype=self.vision_model.dtype
@@ -1976,7 +1976,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         return images_embeds

     def get_input_embeddings(self) -> nn.Embedding:
-        return self.language_model.
+        return self.language_model.get_input_embeddings()

     @torch.no_grad()
     def forward(
@@ -1984,23 +1984,18 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
+        get_embedding: bool = False,
     ) -> torch.Tensor:
-
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-
-
-        )
-
-        return self.language_model(
-            input_ids=None,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-            get_embedding=False,
         )

+        return hidden_states
+
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))

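After this change the model delegates both embedding construction and the language-model call to `general_mm_embed_routine`. A hedged sketch of what such a routine does, inferred only from the call site above; the real signature and the way image embeddings are scattered into placeholder token positions live in `sglang.srt.managers.mm_utils` and will differ:

    from typing import Callable, List, Optional
    import torch

    def mm_embed_routine_sketch(
        input_ids: torch.Tensor,
        image_items: List[object],
        image_data_embedding_func: Callable[[List[object]], torch.Tensor],
        embed_tokens: torch.nn.Embedding,
        run_language_model: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        positions: torch.Tensor,
    ) -> torch.Tensor:
        # 1. Embed the text tokens.
        inputs_embeds = embed_tokens(input_ids)
        # 2. Splice in image embeddings produced by the model-supplied callback.
        #    The real routine targets the image placeholder positions; a simple
        #    prefix overwrite stands in for that here.
        if image_items:
            image_embeds = image_data_embedding_func(image_items)
            inputs_embeds[: image_embeds.shape[0]] = image_embeds
        # 3. Run the language model on the merged embeddings and return hidden states.
        return run_language_model(inputs_embeds, positions)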