sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
CHANGED
@@ -13,15 +13,16 @@
|
|
13
13
|
# ==============================================================================
|
14
14
|
"""
|
15
15
|
The definition of objects transferred between different
|
16
|
-
processes (TokenizerManager, DetokenizerManager,
|
16
|
+
processes (TokenizerManager, DetokenizerManager, Scheduler).
|
17
17
|
"""
|
18
18
|
|
19
19
|
import copy
|
20
20
|
import uuid
|
21
21
|
from dataclasses import dataclass, field
|
22
22
|
from enum import Enum
|
23
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional,
|
23
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
24
24
|
|
25
|
+
from sglang.srt.lora.lora_registry import LoRARef
|
25
26
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
26
27
|
from sglang.srt.multimodal.mm_utils import has_valid_data
|
27
28
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
@@ -42,8 +43,21 @@ class SessionParams:
|
|
42
43
|
drop_previous_output: Optional[bool] = None
|
43
44
|
|
44
45
|
|
45
|
-
|
46
|
-
|
46
|
+
# Type definitions for multimodal input data
|
47
|
+
# Individual data item types for each modality
|
48
|
+
ImageDataInputItem = Union[Image, str, Dict]
|
49
|
+
AudioDataInputItem = Union[str, Dict]
|
50
|
+
VideoDataInputItem = Union[str, Dict]
|
51
|
+
# Union type for any multimodal data item
|
52
|
+
MultimodalDataInputItem = Union[
|
53
|
+
ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
|
54
|
+
]
|
55
|
+
# Format types supporting single items, lists, or nested lists for batch processing
|
56
|
+
MultimodalDataInputFormat = Union[
|
57
|
+
List[List[MultimodalDataInputItem]],
|
58
|
+
List[MultimodalDataInputItem],
|
59
|
+
MultimodalDataInputItem,
|
60
|
+
]
|
47
61
|
|
48
62
|
|
49
63
|
@dataclass
|
@@ -60,13 +74,11 @@ class GenerateReqInput:
|
|
60
74
|
# - List of images (one per request in a batch)
|
61
75
|
# - List of lists of images (multiple images per request)
|
62
76
|
# See also python/sglang/srt/utils.py:load_image for more details.
|
63
|
-
image_data: Optional[
|
64
|
-
Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
|
65
|
-
] = None
|
66
|
-
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
67
|
-
audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
|
77
|
+
image_data: Optional[MultimodalDataInputFormat] = None
|
68
78
|
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
|
69
|
-
video_data: Optional[
|
79
|
+
video_data: Optional[MultimodalDataInputFormat] = None
|
80
|
+
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
81
|
+
audio_data: Optional[MultimodalDataInputFormat] = None
|
70
82
|
# The sampling_params. See descriptions below.
|
71
83
|
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
72
84
|
# The request id.
|
@@ -297,6 +309,9 @@ class GenerateReqInput:
|
|
297
309
|
self.modalities.append("image")
|
298
310
|
elif len(self.image_data[i]) > 1:
|
299
311
|
self.modalities.append("multi-images")
|
312
|
+
else:
|
313
|
+
# Ensure len(self.modalities) == len(self.image_data)
|
314
|
+
self.modalities.append(None)
|
300
315
|
# Expand parallel_sample_num
|
301
316
|
self.image_data = self.image_data * self.parallel_sample_num
|
302
317
|
self.modalities = self.modalities * self.parallel_sample_num
|
@@ -521,19 +536,17 @@ class EmbeddingReqInput:
|
|
521
536
|
# - List of images (one per request in a batch)
|
522
537
|
# - List of lists of images (multiple images per request)
|
523
538
|
# See also python/sglang/srt/utils.py:load_image for more details.
|
524
|
-
image_data: Optional[
|
525
|
-
Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
|
526
|
-
] = None
|
539
|
+
image_data: Optional[MultimodalDataInputFormat] = None
|
527
540
|
# The video input. Like image data, it can be a file name, a url, or base64 encoded string.
|
528
|
-
video_data: Optional[
|
541
|
+
video_data: Optional[MultimodalDataInputFormat] = None
|
529
542
|
# The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
|
530
|
-
audio_data: Optional[
|
543
|
+
audio_data: Optional[MultimodalDataInputFormat] = None
|
531
544
|
# The token ids for text; one can either specify text or input_ids.
|
532
545
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
533
546
|
# The request id.
|
534
547
|
rid: Optional[Union[List[str], str]] = None
|
535
548
|
# Dummy sampling params for compatibility
|
536
|
-
sampling_params: Union[List[Dict], Dict] = None
|
549
|
+
sampling_params: Optional[Union[List[Dict], Dict]] = None
|
537
550
|
# Dummy input embeds for compatibility
|
538
551
|
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
539
552
|
# Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
|
@@ -607,8 +620,6 @@ class EmbeddingReqInput:
|
|
607
620
|
if self.is_cross_encoder_request:
|
608
621
|
return EmbeddingReqInput(
|
609
622
|
text=[self.text[i]] if self.text is not None else None,
|
610
|
-
input_ids=None,
|
611
|
-
image_data=None,
|
612
623
|
sampling_params=self.sampling_params[i],
|
613
624
|
rid=self.rid[i],
|
614
625
|
is_cross_encoder_request=True,
|
@@ -618,6 +629,8 @@ class EmbeddingReqInput:
|
|
618
629
|
text=self.text[i] if self.text is not None else None,
|
619
630
|
input_ids=self.input_ids[i] if self.input_ids is not None else None,
|
620
631
|
image_data=self.image_data[i] if self.image_data is not None else None,
|
632
|
+
audio_data=self.audio_data[i] if self.audio_data is not None else None,
|
633
|
+
video_data=self.video_data[i] if self.video_data is not None else None,
|
621
634
|
sampling_params=self.sampling_params[i],
|
622
635
|
rid=self.rid[i],
|
623
636
|
)
|
@@ -941,17 +954,6 @@ class ProfileReqType(Enum):
|
|
941
954
|
STOP_PROFILE = 2
|
942
955
|
|
943
956
|
|
944
|
-
class ExpertDistributionReq(Enum):
|
945
|
-
START_RECORD = 1
|
946
|
-
STOP_RECORD = 2
|
947
|
-
DUMP_RECORD = 3
|
948
|
-
|
949
|
-
|
950
|
-
@dataclass
|
951
|
-
class ExpertDistributionReqOutput:
|
952
|
-
pass
|
953
|
-
|
954
|
-
|
955
957
|
@dataclass
|
956
958
|
class ProfileReq:
|
957
959
|
type: ProfileReqType
|
@@ -1001,6 +1003,17 @@ class HealthCheckOutput:
|
|
1001
1003
|
pass
|
1002
1004
|
|
1003
1005
|
|
1006
|
+
class ExpertDistributionReq(Enum):
|
1007
|
+
START_RECORD = 1
|
1008
|
+
STOP_RECORD = 2
|
1009
|
+
DUMP_RECORD = 3
|
1010
|
+
|
1011
|
+
|
1012
|
+
@dataclass
|
1013
|
+
class ExpertDistributionReqOutput:
|
1014
|
+
pass
|
1015
|
+
|
1016
|
+
|
1004
1017
|
@dataclass
|
1005
1018
|
class Function:
|
1006
1019
|
description: Optional[str] = None
|
@@ -1055,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
|
|
1055
1068
|
lora_name: str
|
1056
1069
|
# The path of loading.
|
1057
1070
|
lora_path: str
|
1071
|
+
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
|
1072
|
+
lora_id: Optional[str] = None
|
1073
|
+
|
1074
|
+
def to_ref(self) -> LoRARef:
|
1075
|
+
return LoRARef(
|
1076
|
+
lora_id=self.lora_id,
|
1077
|
+
lora_name=self.lora_name,
|
1078
|
+
lora_path=self.lora_path,
|
1079
|
+
)
|
1058
1080
|
|
1059
1081
|
|
1060
1082
|
@dataclass
|
1061
1083
|
class UnloadLoRAAdapterReqInput:
|
1062
1084
|
# The name of lora module to unload.
|
1063
1085
|
lora_name: str
|
1086
|
+
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
|
1087
|
+
lora_id: Optional[str] = None
|
1088
|
+
|
1089
|
+
def to_ref(self) -> LoRARef:
|
1090
|
+
return LoRARef(
|
1091
|
+
lora_id=self.lora_id,
|
1092
|
+
lora_name=self.lora_name,
|
1093
|
+
)
|
1064
1094
|
|
1065
1095
|
|
1066
1096
|
@dataclass
|
1067
1097
|
class LoRAUpdateResult:
|
1068
1098
|
success: bool
|
1069
1099
|
error_message: Optional[str] = None
|
1070
|
-
loaded_adapters: Dict[str,
|
1100
|
+
loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
|
1071
1101
|
|
1072
1102
|
|
1073
1103
|
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
sglang/srt/managers/mm_utils.py
CHANGED
@@ -76,7 +76,6 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
|
|
76
76
|
This function will replace the data-tokens in between with pad_values accordingly
|
77
77
|
"""
|
78
78
|
pad_values = [item.pad_value for item in mm_inputs.mm_items]
|
79
|
-
print(f"{mm_inputs.mm_items=}")
|
80
79
|
data_token_pairs = self.data_token_id_pairs
|
81
80
|
mm_inputs.data_offsets = []
|
82
81
|
if data_token_pairs is None:
|
@@ -222,17 +221,17 @@ def _get_precomputed_embedding(
|
|
222
221
|
items: List[MultimodalDataItem],
|
223
222
|
) -> Optional[torch.Tensor]:
|
224
223
|
"""
|
225
|
-
If all items have
|
226
|
-
If some but not all have
|
227
|
-
If none have
|
224
|
+
If all items have precomputed_embeddings, return their concatenation.
|
225
|
+
If some but not all have precomputed_embeddings, raise NotImplementedError.
|
226
|
+
If none have precomputed_embeddings, return None.
|
228
227
|
"""
|
229
|
-
|
230
|
-
if any(feature is not None for feature in
|
231
|
-
if not all(feature is not None for feature in
|
228
|
+
precomputed_embeddings = [item.precomputed_embeddings for item in items]
|
229
|
+
if any(feature is not None for feature in precomputed_embeddings):
|
230
|
+
if not all(feature is not None for feature in precomputed_embeddings):
|
232
231
|
raise NotImplementedError(
|
233
232
|
"MM inputs where only some items are precomputed."
|
234
233
|
)
|
235
|
-
result = torch.concat(
|
234
|
+
result = torch.concat(precomputed_embeddings)
|
236
235
|
# some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
|
237
236
|
result = result.reshape(-1, result.shape[-1])
|
238
237
|
return result
|