sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/mm_utils.py
@@ -2,6 +2,7 @@
 Multi-modality utils
 """
 
+import dataclasses
 import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
@@ -15,10 +16,15 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalInputs,
     global_server_args_dict,
 )
+from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import flatten_nested_list, print_warning_once
+from sglang.utils import logger
 
-logger = logging.getLogger(__name__)
+# NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
+# to ensure consistent logging behavior across the codebase. This prevents issues with log
+# propagation that can cause some log messages (like 'server is fired up') to not appear
+# in the console when multimodal support is enabled.
 
 
 class MultiModalityDataPaddingPattern:
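
The NOTE in this hunk is about standard Python logging propagation: a logger created with logging.getLogger(__name__) only prints if its records propagate up to an ancestor that has a configured handler. A minimal, generic illustration of the failure mode the comment describes (this is stock logging behavior, not sglang's actual setup):

    import logging

    # An ancestor logger with a real handler attached.
    shared = logging.getLogger("sglang")
    shared.addHandler(logging.StreamHandler())
    shared.setLevel(logging.INFO)

    # A module-level child logger: its records normally propagate upward.
    child = logging.getLogger("sglang.srt.managers.mm_utils")
    child.info("visible: record propagates up to the 'sglang' handler")

    child.propagate = False  # what a misconfigured child effectively does
    child.info("silent: this record never reaches any handler")
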
@@ -41,17 +47,32 @@ class MultiModalityDataPaddingPattern:
 class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
     """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
 
+    The padded value in a region enclosed by a token pair will be the same one, as the MultimodalDataItem's pad value
+
     This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
     """
 
-    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
+    def __init__(
+        self,
+        data_token_pairs: Optional[List[Tuple[int, int]]],
+        data_start_token_ids: Optional[List[int]] = None,
+    ) -> None:
+        """
+
+        Args:
+            data_start_token_ids marks the start of a single multimodal data
+            See Minicpmo's slice_start_id for example
+        """
         self.data_token_id_pairs = data_token_pairs
+        self.data_start_token_ids = data_start_token_ids or [
+            s for s, _e in data_token_pairs
+        ]
 
     def pad_input_tokens(
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-        This function will replace the data-tokens inbetween with pad_values accordingly
+        This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
         data_token_pairs = self.data_token_id_pairs
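
To make the token-pair contract concrete: every token strictly between a start/end pair is replaced by the pad value of the corresponding MultimodalDataItem, and each start token advances to the next item's pad value. A self-contained sketch of that behavior (the token IDs and the pad_between_pairs helper are made up for illustration; the real logic is pad_input_tokens above):

    def pad_between_pairs(input_ids, token_pairs, pad_values):
        starts = {s for s, _ in token_pairs}
        ends = {e for _, e in token_pairs}
        out, data_idx, in_region = list(input_ids), -1, False
        for i, tok in enumerate(out):
            if tok in starts:
                data_idx, in_region = data_idx + 1, True
            elif tok in ends:
                in_region = False
            elif in_region:
                out[i] = pad_values[data_idx]  # one pad value per enclosed region
        return out

    # With a hypothetical <image>/</image> id pair (90, 91) and two items whose
    # pad values are 700 and 701:
    assert pad_between_pairs(
        [1, 90, 5, 5, 91, 2, 90, 6, 91], [(90, 91)], [700, 701]
    ) == [1, 90, 700, 700, 91, 2, 90, 701, 91]
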
@@ -79,7 +100,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         for start_idx, end_idx in zip(start_indices, end_indices):
             padded_ids.extend(input_ids[last_idx : start_idx + 1])
 
-            if input_ids[start_idx] in start_token_ids:
+            if input_ids[start_idx] in self.data_start_token_ids:
                 data_idx += 1
                 mm_inputs.data_offsets += [start_idx]
@@ -170,30 +191,140 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPattern):
                 output_ids_tensor[start_idx:end_idx] = pad_value
             else:
                 logger.warning(f"Skipping region {i} due to None pad_value.")
-
         return output_ids_tensor.tolist()
 
 
+embedding_cache = None
+
+
+def init_embedding_cache(max_size: int):
+    global embedding_cache
+    embedding_cache = MultiModalCache(max_size)
+
+
+def get_embedding_hash(embedding_items: List[MultimodalDataItem]) -> int:
+    hash_list = [item.hash for item in embedding_items]
+    return hash(tuple(hash_list))
+
+
+def get_embedding_chunk(
+    embedding: torch.Tensor,
+    extend_prefix_len: int,
+    extend_seq_len: int,
+    items_offset: List[Tuple[int, int]],
+) -> Tuple[torch.Tensor, int, int]:
+    """
+    Extract a chunk of embeddings based on the specified prefix length, sequence length, and offset ranges.
+
+    Args:
+        embedding: The full embedding tensor to extract a chunk from
+        extend_prefix_len: The starting position (prefix length) for extraction
+        extend_seq_len: The number of tokens to extract
+        items_offset: List of [start, end] offset ranges for multimodal items in the input sequence
+
+    Returns:
+        A tuple containing:
+        - The extracted embedding chunk as a tensor
+        - The start index used for extraction
+        - The end index used for extraction
+
+    Note:
+        If there's no overlap between the requested range and the offset ranges,
+        an empty tensor is returned with zeros for start and end indices.
+    """
+    start_index, end_index = 0, 0
+    extend_start_index = extend_prefix_len
+    extend_end_index = extend_prefix_len + extend_seq_len - 1
+
+    for start, end in items_offset:
+        if extend_start_index >= start and extend_start_index <= end:
+            start_index += extend_start_index - start
+        elif extend_start_index > end:
+            start_index += end - start + 1
+
+        if extend_end_index >= start and extend_end_index <= end:
+            end_index += extend_end_index - start + 1
+        elif extend_end_index > end:
+            end_index += end - start + 1
+    # some models' embedding is 3-dim, reshape it to 2-dim
+    embedding = embedding.reshape(-1, embedding.shape[-1])
+    embedding_chunk = embedding[start_index:end_index]
+    return embedding_chunk, start_index, end_index
+
+
 def get_embedding_and_mask(
     data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
     embedding_items: List[MultimodalDataItem],
     placeholder_tensor: torch.Tensor,
     input_ids: torch.Tensor,
-):
+    items_size: List[int],
+    prefix_length: List[int],
+    extend_length: List[int],
+    items_offset_list: List[List[Tuple[int, int]]],
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
-    Get the multimodal embedding and its mask from input_ids
+    Generate multimodal embeddings and create a mask for identifying their positions in the input sequence.
+
+    Args:
+        data_embedding_func: Function that generates embeddings for multimodal items
+        embedding_items: List of multimodal items to embed
+        placeholder_tensor: Tensor containing token IDs that serve as placeholders for multimodal content
+        input_ids: The input token IDs tensor
+        items_size: Cumulative sizes of multimodal items per request
+        prefix_length: Prefix lengths for each request
+        extend_length: Sequence lengths for each request
+        items_offset_list: List of offset ranges for multimodal items in each request
 
+    Returns:
+        A tuple containing:
+        - The generated embeddings tensor
+        - A boolean mask tensor indicating where these embeddings should be placed
+
+    Raises:
+        AssertionError: If the number of multimodal tokens in input_ids doesn't match
+            the number of tokens in the generated embeddings
     """
     # 1. Get the embedding
-    embedding = data_embedding_func(embedding_items)
+    # Calculate embedding for each request, try to get it from cache to avoid repeated calculation
+    embedding_list = []
+    for i in range(len(items_size) - 1):
+        if items_size[i] == items_size[i + 1]:
+            continue
+        embedding_items_per_req = embedding_items[items_size[i] : items_size[i + 1]]
+        items_offset = items_offset_list[i]
+        embedding_items_hash = get_embedding_hash(embedding_items_per_req)
+        # if all items have been prefixed, we do not need to calculate embedding
+        if all([offset_end < prefix_length[i] for _, offset_end in items_offset]):
+            continue
+        embedding_per_req = embedding_cache.get(embedding_items_hash)
+        if embedding_per_req is None:
+            embedding_per_req = data_embedding_func(embedding_items_per_req)
+            if not embedding_cache.put(embedding_items_hash, embedding_per_req):
+                print_warning_once(
+                    "Multimodal embedding cache is full. Consider increasing the "
+                    "`SGLANG_VLM_CACHE_SIZE_MB` environment variable."
+                )
 
+        embedding_per_req_chunk, _, end_index = get_embedding_chunk(
+            embedding=embedding_per_req,
+            extend_prefix_len=prefix_length[i],
+            extend_seq_len=extend_length[i],
+            items_offset=items_offset,
+        )
+        # remove this item from the cache if the chunk reaches the end
+        embedding_per_req_length = (
+            embedding_per_req.shape[0]
+            if embedding_per_req.dim() == 2
+            else embedding_per_req.shape[0] * embedding_per_req.shape[1]
+        )
+        if end_index == embedding_per_req_length:
+            embedding_cache.free(embedding_items_hash)
+        embedding_list.append(embedding_per_req_chunk)
+    if len(embedding_list) == 0:
+        return None, None
+    embedding = torch.concat(embedding_list, dim=0)
     # 2. Check the embedding
-    if embedding.dim() == 2:
-        num_mm_tokens_in_embedding = embedding.shape[0]
-    else:
-        num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
-
-    # the mask of multimodal tokens from input_ids
+    num_mm_tokens_in_embedding = embedding.shape[0]
     special_multimodal_mask = torch.isin(
         input_ids,
         placeholder_tensor,
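
The index arithmetic in get_embedding_chunk is easiest to see with a worked case. A standalone mirror of the logic above (plain Python, for illustration rather than importing sglang):

    def chunk_indices(extend_prefix_len, extend_seq_len, items_offset):
        start_index = end_index = 0
        lo = extend_prefix_len                       # first token of this prefill pass
        hi = extend_prefix_len + extend_seq_len - 1  # last token of this prefill pass
        for start, end in items_offset:
            if start <= lo <= end:
                start_index += lo - start
            elif lo > end:
                start_index += end - start + 1
            if start <= hi <= end:
                end_index += hi - start + 1
            elif hi > end:
                end_index += end - start + 1
        return start_index, end_index

    # One image occupying input positions 10..19, chunked prefill split at position 15:
    # the first pass (prefix 0, 15 tokens) consumes embedding rows [0, 5); the second
    # pass (prefix 15, 10 tokens) consumes the remaining cached rows [5, 10).
    assert chunk_indices(0, 15, [(10, 19)]) == (0, 5)
    assert chunk_indices(15, 10, [(10, 19)]) == (5, 10)
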
@@ -202,14 +333,11 @@ def get_embedding_and_mask(
     num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
-            f"Number of tokens in multimodal embedding does not match those in the input text."
+            f"Number of tokens in multimodal embedding does not match those in the input text. "
             f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
             "tokens from multimodal embeddings."
         )
         if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
-            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
-            # a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
-            # extend_start_loc and extend_seq_lens
             chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
             if chunked_prefill_size != -1:
                 logger.warning(
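
The embedding_cache introduced earlier in this file is a MultiModalCache keyed by a combined hash of a request's items (get_embedding_hash). A minimal sketch of the get/put/free contract those call sites rely on, assuming a simple size-bounded store; this stand-in is not the sglang implementation:

    import torch

    # Hypothetical stand-in for MultiModalCache: a size-bounded tensor store.
    # put() returns False instead of evicting when the budget would be exceeded,
    # which is what triggers the "cache is full" warning above.
    class TensorCacheSketch:
        def __init__(self, max_bytes: int):
            self.max_bytes = max_bytes
            self.used = 0
            self.store = {}

        def get(self, key: int):
            return self.store.get(key)

        def put(self, key: int, value: torch.Tensor) -> bool:
            size = value.element_size() * value.numel()
            if self.used + size > self.max_bytes:
                return False
            self.store[key] = value
            self.used += size
            return True

        def free(self, key: int) -> None:
            value = self.store.pop(key, None)
            if value is not None:
                self.used -= value.element_size() * value.numel()
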
@@ -230,7 +358,9 @@
 
 
 def embed_mm_inputs(
-    mm_inputs: MultimodalInputs,
+    mm_inputs_list: List[MultimodalInputs],
+    extend_prefix_lens: List[int],
+    extend_seq_lens: List[int],
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
     image_data_embedding_func: Callable[
@@ -242,125 +372,133 @@ def embed_mm_inputs(
     placeholder_tokens: dict[Modality, List[int]] = None,
 ) -> Optional[torch.Tensor]:
     """
-    Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations
+    Embed multimodal inputs and integrate them with text token embeddings.
+
+    Args:
+        mm_inputs_list: List of multimodal inputs to process
+        extend_prefix_lens: Prefix lengths for each request
+        extend_seq_lens: Sequence lengths for each request
+        input_ids: Input token IDs tensor
+        input_embedding: Embedding layer for text tokens
+        image_data_embedding_func: Function to embed image data
+        audio_data_embedding_func: Function to embed audio data
+        placeholder_tokens: Token IDs for multimodal placeholders (uses pad_values if None)
 
-    Args:
-        placeholder_tokens: denoting the token of multimodal data in input_ids.
-            If none, the pad_values of multimodal items are used
-
-    Returns:
-        final embedding: Optional[torch.Tensor]
+    Returns:
+        Combined embedding tensor with multimodal content integrated
     """
 
-    if mm_inputs is None:
+    if mm_inputs_list is None:
         return None
 
     # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
     # we assume that multimodal data are represented with its pad_values in input_ids
-    # See `pad_input_ids` for more detail
-
-    # if placeholder_tokens is specified
-    if placeholder_tokens is not None:
-        placeholder_token_ids = flatten_nested_list(
-            [placeholder_token for placeholder_token in placeholder_tokens.values()]
-        )
-    else:
-        placeholder_token_ids = [item.pad_value for item in mm_inputs.mm_items]
-
-    assert isinstance(placeholder_token_ids[0], int)
-
-    placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)
-
-    placeholder_masks = torch.isin(input_ids, placeholder_tensor)
+    item_flatten_list = []
+    for mm_inputs in mm_inputs_list:
+        item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
 
-    appearing_pad_values = torch.unique(
-        input_ids[placeholder_masks], return_counts=False
-    )
+    embeddings, masks = [], []
 
-    if appearing_pad_values.numel() == 0:
-        # all been prefixed
-        inputs_embeds = input_embedding(input_ids)
-    else:
-        appearing_items = [
-            item
-            for item in mm_inputs.mm_items
-            if item.pad_value is not None and item.pad_value in appearing_pad_values
-        ]
-
-        using_all_items = False
-        if len(appearing_items) == 0:
-            # This happens mostly when arg placeholder_token_ids is passed
-            logger.warning(
-                "No multimodal data item's pad value exist in placeholder ids. Using all items"
+    # 2. Get multimodal embedding separately
+    # TODO: make this more generic
+    # Try get image embedding if any
+    if (
+        any(True for item in item_flatten_list if item.is_image())
+        and image_data_embedding_func
+    ):
+        items = [item for item in item_flatten_list if item.is_image()]
+        placeholder_tensor = torch.tensor(
+            [item.pad_value for item in items],
+            device=input_ids.device,
+        )
+        # calculate per request items length offset
+        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+        items_offsets = []
+        for i, mm_inputs in enumerate(mm_inputs_list):
+            image_items = [item for item in mm_inputs.mm_items if item.is_image()]
+            items_size[i + 1] = len(image_items)
+            items_offsets.append(
+                flatten_nested_list(
+                    [
+                        item.image_offsets
+                        for item in mm_inputs.mm_items
+                        if item.is_image()
+                    ]
+                )
             )
-            using_all_items = True
-            appearing_items = mm_inputs.mm_items
+        items_size = torch.cumsum(items_size, dim=0).tolist()
 
-        embeddings, masks = [], []
+        embedding, mask = get_embedding_and_mask(
+            data_embedding_func=image_data_embedding_func,
+            embedding_items=items,
+            placeholder_tensor=placeholder_tensor,
+            input_ids=input_ids,
+            items_size=items_size,
+            prefix_length=extend_prefix_lens,
+            extend_length=extend_seq_lens,
+            items_offset_list=items_offsets,
+        )
+        embeddings += [embedding]
+        masks += [mask]
 
-        # 2. Get multimodal embedding separately
-        # TODO: make this more generic
-        # Try get image embedding if any
-        if (
-            any(True for item in appearing_items if item.is_image())
-            and image_data_embedding_func
-        ):
-            items = [item for item in appearing_items if item.is_image()]
-            embedding, mask = get_embedding_and_mask(
-                data_embedding_func=image_data_embedding_func,
-                embedding_items=items,
-                placeholder_tensor=(
-                    # use the specified modality token to identify the location to embed
-                    placeholder_tokens[Modality.IMAGE]
-                    if using_all_items
-                    else torch.tensor(
-                        [item.pad_value for item in items],
-                        device=input_ids.device,
-                    )
-                ),
-                input_ids=input_ids,
+    # Try get audio embedding if any
+    if (
+        any(True for item in item_flatten_list if item.is_audio())
+        and audio_data_embedding_func
+    ):
+        items = [item for item in item_flatten_list if item.is_audio()]
+        placeholder_tensor = torch.tensor(
+            [item.pad_value for item in items],
+            device=input_ids.device,
+        )
+        items_offsets = []
+        # calculate per request items length offset
+        items_size = torch.zeros(len(mm_inputs_list) + 1, dtype=int)
+        for i, mm_inputs in enumerate(mm_inputs_list):
+            audio_items = [item for item in mm_inputs.mm_items if item.is_audio()]
+            items_size[i + 1] = len(audio_items)
+            items_offsets.append(
+                flatten_nested_list(
+                    [
+                        item.audio_offsets
+                        for item in mm_inputs.mm_items
+                        if item.is_audio()
+                    ]
+                )
             )
-            embeddings += [embedding]
-            masks += [mask]
+        items_size = torch.cumsum(items_size, dim=0)
 
-        # Try get audio embedding if any
-        if (
-            any(True for item in appearing_items if item.is_audio())
-            and audio_data_embedding_func
-        ):
-            items = [item for item in appearing_items if item.is_audio()]
-            embedding, mask = get_embedding_and_mask(
-                data_embedding_func=audio_data_embedding_func,
-                embedding_items=items,
-                placeholder_tensor=(
-                    placeholder_tokens[Modality.AUDIO]
-                    if using_all_items
-                    else torch.tensor(
-                        [item.pad_value for item in items],
-                        device=input_ids.device,
-                    )
-                ),
-                input_ids=input_ids,
-            )
-            embeddings += [embedding]
-            masks += [mask]
-
-        # 3. Get input embeddings
-        vocab_size = input_embedding.num_embeddings
-        # Important: clamp after getting original multimodal regions
-        # Clamp input ids. This is because the input_ids for the multimodal tokens are
-        # filled with the hash values of the multimodal for the prefix matching in the radix attention.
-        # Their values are useless because their embeddings will be replaced by vision embeddings anyway.
-        input_ids.clamp_(min=0, max=vocab_size - 1)
-        inputs_embeds = input_embedding(input_ids)
-
-        # 4. Scatter embeddings into input embedding
-        for embedding, mask in zip(embeddings, masks):
-            mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-            inputs_embeds = inputs_embeds.masked_scatter(
-                mask,
-                embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-            )
+        embedding, mask = get_embedding_and_mask(
+            data_embedding_func=audio_data_embedding_func,
+            embedding_items=items,
+            placeholder_tensor=placeholder_tensor,
+            input_ids=input_ids,
+            items_size=items_size,
+            prefix_length=extend_prefix_lens,
+            extend_length=extend_seq_lens,
+            items_offset_list=items_offsets,
+        )
+        embeddings += [embedding]
+        masks += [mask]
+
+    # 3. Get input embeddings
+    vocab_size = input_embedding.num_embeddings
+    # Important: clamp after getting original multimodal regions
+    # Clamp input ids. This is because the input_ids for the multimodal tokens are
+    # filled with the hash values of the multimodal for the prefix matching in the radix attention.
+    # Their values are useless because their embeddings will be replaced by vision embeddings anyway.
+    input_ids.clamp_(min=0, max=vocab_size - 1)
+    inputs_embeds = input_embedding(input_ids)
+
+    # 4. scatter embeddings into input embedding
+    for embedding, mask in zip(embeddings, masks):
+        if embedding is None or mask is None:
+            continue
+        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+        inputs_embeds = inputs_embeds.masked_scatter(
+            mask,
+            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
+        )
     return inputs_embeds
 
 
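
Step 4 above relies on torch.Tensor.masked_scatter: the boolean placeholder mask is broadcast across the hidden dimension, and the multimodal rows are written into the text embeddings in order. A minimal sketch of that mechanic with toy shapes (the token IDs and values here are made up):

    import torch

    # Toy shapes: 6 input tokens, hidden size 4, two of them are image placeholders.
    input_ids = torch.tensor([1, 2, 900, 900, 3, 4])  # 900 = hypothetical pad value
    inputs_embeds = torch.zeros(6, 4)                 # stand-in for text embeddings
    image_embedding = torch.ones(2, 4)                # stand-in for vision output

    placeholder_tensor = torch.tensor([900])
    mask = torch.isin(input_ids, placeholder_tensor).unsqueeze(-1)  # shape (6, 1)

    # expand_as broadcasts the (6, 1) mask to (6, 4); masked_scatter then fills the
    # True positions row by row with the image embedding values, in order.
    mixed = inputs_embeds.masked_scatter(mask.expand_as(inputs_embeds), image_embedding)
    assert mixed[2].eq(1).all() and mixed[3].eq(1).all() and mixed[0].eq(0).all()
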
@@ -368,37 +506,53 @@ def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
     language_model: nn.Module,
-    image_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
+    image_data_embedding_func: Optional[
+        Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
-    audio_data_embedding_func: Callable[
-        [List[MultimodalDataItem]], torch.Tensor
+    audio_data_embedding_func: Optional[
+        Callable[[List[MultimodalDataItem]], torch.Tensor]
     ] = None,
-    placeholder_tokens: dict[Modality, List[int]] = None,
+    placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
     **kwargs,
 ) -> torch.Tensor:
     """
-    A general wrapper function to get final input embeds from multimodal models with a language model as causal model
+    Process multimodal inputs and forward through language model.
 
-    Args:
-        placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
-        image_data_embedding_func : the function returning the image embedding
-        audio_data_embedding_func : the function returning the image embedding
-
-    Returns:
-        forwarded hidden states
+    Args:
+        input_ids: Input token IDs tensor
+        forward_batch: Batch information for model forward pass
+        language_model: Base language model to use
+        image_data_embedding_func: Function to embed image data
+        audio_data_embedding_func: Function to embed audio data
+        placeholder_tokens: Token IDs for multimodal placeholders
+        **kwargs: Additional arguments passed to language model
 
+    Returns:
+        Hidden states from language model forward pass
     """
-
     assert hasattr(language_model, "get_input_embeddings")
     embed_tokens = language_model.get_input_embeddings()
     if (
         not forward_batch.forward_mode.is_decode()
         and forward_batch.contains_mm_inputs()
     ):
-        mm_input = forward_batch.merge_mm_inputs()
+        mm_inputs_list = [
+            mm_input for mm_input in forward_batch.mm_inputs if mm_input is not None
+        ]
+        extend_prefix_lens = [
+            prefix_len
+            for i, prefix_len in enumerate(forward_batch.extend_prefix_lens_cpu)
+            if forward_batch.mm_inputs[i] is not None
+        ]
+        extend_seq_lens = [
+            seq_len
+            for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
+            if forward_batch.mm_inputs[i] is not None
+        ]
        inputs_embeds = embed_mm_inputs(
-            mm_inputs=mm_input,
+            mm_inputs_list=mm_inputs_list,
+            extend_prefix_lens=extend_prefix_lens,
+            extend_seq_lens=extend_seq_lens,
             input_ids=input_ids,
             input_embedding=embed_tokens,
             image_data_embedding_func=image_data_embedding_func,
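
The per-request filtering in this hunk keeps three lists aligned: only requests that actually carry multimodal inputs contribute an entry to mm_inputs_list, extend_prefix_lens, and extend_seq_lens, so prefix_length[i] inside get_embedding_and_mask always pairs with the items of request i. A toy illustration of that alignment (values are made up; only the filtering idiom comes from the hunk above):

    # Toy batch of 4 requests; None means the request is text-only.
    mm_inputs = [None, "img_a", None, "img_b"]  # stand-ins for MultimodalInputs
    extend_prefix_lens_cpu = [0, 16, 8, 32]
    extend_seq_lens_cpu = [4, 12, 4, 8]

    keep = [i for i, mm in enumerate(mm_inputs) if mm is not None]
    mm_inputs_list = [mm_inputs[i] for i in keep]                   # ["img_a", "img_b"]
    extend_prefix_lens = [extend_prefix_lens_cpu[i] for i in keep]  # [16, 32]
    extend_seq_lens = [extend_seq_lens_cpu[i] for i in keep]        # [12, 8]
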