sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -13,15 +13,16 @@
13
13
  # ==============================================================================
14
14
  """
15
15
  The definition of objects transferred between different
16
- processes (TokenizerManager, DetokenizerManager, Controller).
16
+ processes (TokenizerManager, DetokenizerManager, Scheduler).
17
17
  """
18
18
 
19
19
  import copy
20
20
  import uuid
21
21
  from dataclasses import dataclass, field
22
22
  from enum import Enum
23
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
23
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
24
24
 
25
+ from sglang.srt.lora.lora_registry import LoRARef
25
26
  from sglang.srt.managers.schedule_batch import BaseFinishReason
26
27
  from sglang.srt.multimodal.mm_utils import has_valid_data
27
28
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -42,8 +43,21 @@ class SessionParams:
42
43
  drop_previous_output: Optional[bool] = None
43
44
 
44
45
 
45
- AudioDataItem = Union[str, Dict]
46
- ImageDataItem = Union[Image, str, Dict]
46
+ # Type definitions for multimodal input data
47
+ # Individual data item types for each modality
48
+ ImageDataInputItem = Union[Image, str, Dict]
49
+ AudioDataInputItem = Union[str, Dict]
50
+ VideoDataInputItem = Union[str, Dict]
51
+ # Union type for any multimodal data item
52
+ MultimodalDataInputItem = Union[
53
+ ImageDataInputItem, VideoDataInputItem, AudioDataInputItem
54
+ ]
55
+ # Format types supporting single items, lists, or nested lists for batch processing
56
+ MultimodalDataInputFormat = Union[
57
+ List[List[MultimodalDataInputItem]],
58
+ List[MultimodalDataInputItem],
59
+ MultimodalDataInputItem,
60
+ ]
47
61
 
48
62
 
49
63
  @dataclass
@@ -60,13 +74,11 @@ class GenerateReqInput:
60
74
  # - List of images (one per request in a batch)
61
75
  # - List of lists of images (multiple images per request)
62
76
  # See also python/sglang/srt/utils.py:load_image for more details.
63
- image_data: Optional[
64
- Union[List[List[ImageDataItem]], List[ImageDataItem], ImageDataItem]
65
- ] = None
66
- # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
67
- audio_data: Optional[Union[List[AudioDataItem], AudioDataItem]] = None
77
+ image_data: Optional[MultimodalDataInputFormat] = None
68
78
  # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
69
- video_data: Optional[Union[List[List[str]], List[str], str]] = None
79
+ video_data: Optional[MultimodalDataInputFormat] = None
80
+ # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
81
+ audio_data: Optional[MultimodalDataInputFormat] = None
70
82
  # The sampling_params. See descriptions below.
71
83
  sampling_params: Optional[Union[List[Dict], Dict]] = None
72
84
  # The request id.
@@ -297,6 +309,9 @@ class GenerateReqInput:
297
309
  self.modalities.append("image")
298
310
  elif len(self.image_data[i]) > 1:
299
311
  self.modalities.append("multi-images")
312
+ else:
313
+ # Ensure len(self.modalities) == len(self.image_data)
314
+ self.modalities.append(None)
300
315
  # Expand parallel_sample_num
301
316
  self.image_data = self.image_data * self.parallel_sample_num
302
317
  self.modalities = self.modalities * self.parallel_sample_num
@@ -521,19 +536,17 @@ class EmbeddingReqInput:
521
536
  # - List of images (one per request in a batch)
522
537
  # - List of lists of images (multiple images per request)
523
538
  # See also python/sglang/srt/utils.py:load_image for more details.
524
- image_data: Optional[
525
- Union[List[List[Union[Image, str]]], List[Union[Image, str]], Union[Image, str]]
526
- ] = None
539
+ image_data: Optional[MultimodalDataInputFormat] = None
527
540
  # The video input. Like image data, it can be a file name, a url, or base64 encoded string.
528
- video_data: Optional[Union[List[str], str]] = None
541
+ video_data: Optional[MultimodalDataInputFormat] = None
529
542
  # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
530
- audio_data: Optional[Union[List[str], str]] = None
543
+ audio_data: Optional[MultimodalDataInputFormat] = None
531
544
  # The token ids for text; one can either specify text or input_ids.
532
545
  input_ids: Optional[Union[List[List[int]], List[int]]] = None
533
546
  # The request id.
534
547
  rid: Optional[Union[List[str], str]] = None
535
548
  # Dummy sampling params for compatibility
536
- sampling_params: Union[List[Dict], Dict] = None
549
+ sampling_params: Optional[Union[List[Dict], Dict]] = None
537
550
  # Dummy input embeds for compatibility
538
551
  input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
539
552
  # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
@@ -607,8 +620,6 @@ class EmbeddingReqInput:
607
620
  if self.is_cross_encoder_request:
608
621
  return EmbeddingReqInput(
609
622
  text=[self.text[i]] if self.text is not None else None,
610
- input_ids=None,
611
- image_data=None,
612
623
  sampling_params=self.sampling_params[i],
613
624
  rid=self.rid[i],
614
625
  is_cross_encoder_request=True,
@@ -618,6 +629,8 @@ class EmbeddingReqInput:
618
629
  text=self.text[i] if self.text is not None else None,
619
630
  input_ids=self.input_ids[i] if self.input_ids is not None else None,
620
631
  image_data=self.image_data[i] if self.image_data is not None else None,
632
+ audio_data=self.audio_data[i] if self.audio_data is not None else None,
633
+ video_data=self.video_data[i] if self.video_data is not None else None,
621
634
  sampling_params=self.sampling_params[i],
622
635
  rid=self.rid[i],
623
636
  )
@@ -941,17 +954,6 @@ class ProfileReqType(Enum):
941
954
  STOP_PROFILE = 2
942
955
 
943
956
 
944
- class ExpertDistributionReq(Enum):
945
- START_RECORD = 1
946
- STOP_RECORD = 2
947
- DUMP_RECORD = 3
948
-
949
-
950
- @dataclass
951
- class ExpertDistributionReqOutput:
952
- pass
953
-
954
-
955
957
  @dataclass
956
958
  class ProfileReq:
957
959
  type: ProfileReqType
@@ -1001,6 +1003,17 @@ class HealthCheckOutput:
1001
1003
  pass
1002
1004
 
1003
1005
 
1006
+ class ExpertDistributionReq(Enum):
1007
+ START_RECORD = 1
1008
+ STOP_RECORD = 2
1009
+ DUMP_RECORD = 3
1010
+
1011
+
1012
+ @dataclass
1013
+ class ExpertDistributionReqOutput:
1014
+ pass
1015
+
1016
+
1004
1017
  @dataclass
1005
1018
  class Function:
1006
1019
  description: Optional[str] = None
@@ -1055,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
1055
1068
  lora_name: str
1056
1069
  # The path of loading.
1057
1070
  lora_path: str
1071
+ # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
1072
+ lora_id: Optional[str] = None
1073
+
1074
+ def to_ref(self) -> LoRARef:
1075
+ return LoRARef(
1076
+ lora_id=self.lora_id,
1077
+ lora_name=self.lora_name,
1078
+ lora_path=self.lora_path,
1079
+ )
1058
1080
 
1059
1081
 
1060
1082
  @dataclass
1061
1083
  class UnloadLoRAAdapterReqInput:
1062
1084
  # The name of lora module to unload.
1063
1085
  lora_name: str
1086
+ # The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
1087
+ lora_id: Optional[str] = None
1088
+
1089
+ def to_ref(self) -> LoRARef:
1090
+ return LoRARef(
1091
+ lora_id=self.lora_id,
1092
+ lora_name=self.lora_name,
1093
+ )
1064
1094
 
1065
1095
 
1066
1096
  @dataclass
1067
1097
  class LoRAUpdateResult:
1068
1098
  success: bool
1069
1099
  error_message: Optional[str] = None
1070
- loaded_adapters: Dict[str, str] = field(default_factory=dict)
1100
+ loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
1071
1101
 
1072
1102
 
1073
1103
  LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
@@ -76,7 +76,6 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
76
76
  This function will replace the data-tokens in between with pad_values accordingly
77
77
  """
78
78
  pad_values = [item.pad_value for item in mm_inputs.mm_items]
79
- print(f"{mm_inputs.mm_items=}")
80
79
  data_token_pairs = self.data_token_id_pairs
81
80
  mm_inputs.data_offsets = []
82
81
  if data_token_pairs is None:
@@ -222,17 +221,17 @@ def _get_precomputed_embedding(
222
221
  items: List[MultimodalDataItem],
223
222
  ) -> Optional[torch.Tensor]:
224
223
  """
225
- If all items have precomputed_features, return their concatenation.
226
- If some but not all have precomputed_features, raise NotImplementedError.
227
- If none have precomputed_features, return None.
224
+ If all items have precomputed_embeddings, return their concatenation.
225
+ If some but not all have precomputed_embeddings, raise NotImplementedError.
226
+ If none have precomputed_embeddings, return None.
228
227
  """
229
- precomputed_features = [item.precomputed_features for item in items]
230
- if any(feature is not None for feature in precomputed_features):
231
- if not all(feature is not None for feature in precomputed_features):
228
+ precomputed_embeddings = [item.precomputed_embeddings for item in items]
229
+ if any(feature is not None for feature in precomputed_embeddings):
230
+ if not all(feature is not None for feature in precomputed_embeddings):
232
231
  raise NotImplementedError(
233
232
  "MM inputs where only some items are precomputed."
234
233
  )
235
- result = torch.concat(precomputed_features)
234
+ result = torch.concat(precomputed_embeddings)
236
235
  # some models embedding is 3-dim, reshape it to 2-dim (similar to get_embedding_chunk)
237
236
  result = result.reshape(-1, result.shape[-1])
238
237
  return result