sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py

@@ -149,6 +149,7 @@ class HiCacheController:
         self,
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
+        page_size: int,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
     ):
@@ -156,6 +157,7 @@ class HiCacheController:
         self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
         self.mem_pool_host = mem_pool_host
         self.write_policy = write_policy
+        self.page_size = page_size

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -184,7 +186,12 @@ class HiCacheController:
         self.load_stream = torch.cuda.Stream()

         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -205,7 +212,12 @@ class HiCacheController:
         self.ack_load_queue.queue.clear()

         self.write_thread = threading.Thread(
-            target=self.write_thread_func_buffer, daemon=True
+            target=(
+                self.write_thread_func_buffer
+                if self.page_size == 1
+                else self.write_thread_func_direct
+            ),
+            daemon=True,
         )
         self.load_thread = threading.Thread(
             target=self.load_thread_func_layer_by_layer, daemon=True
@@ -260,10 +272,12 @@ class HiCacheController:
         while not self.stop_event.is_set():
             try:
                 operation = self.write_queue.get(block=True, timeout=1)
-                operation.data = self.mem_pool_device.get_flat_data(
-                    operation.device_indices
+                self.mem_pool_host.write_page_all_layers(
+                    operation.host_indices,
+                    operation.device_indices,
+                    self.mem_pool_device,
                 )
-                self.mem_pool_host.transfer(operation.host_indices, operation.data)
+                self.write_stream.synchronize()
                 self.mem_pool_host.complete_io(operation.host_indices)
                 for node_id in operation.node_ids:
                     if node_id != 0:
@@ -320,12 +334,21 @@ class HiCacheController:

         self.layer_done_counter.reset()
         for i in range(self.mem_pool_host.layer_num):
-            flat_data = self.mem_pool_host.get_flat_data_by_layer(
-                batch_operation.host_indices, i
-            )
-            self.mem_pool_device.transfer_per_layer(
-                batch_operation.device_indices, flat_data, i
-            )
+            if self.page_size == 1:
+                flat_data = self.mem_pool_host.get_flat_data_by_layer(
+                    batch_operation.host_indices, i
+                )
+                self.mem_pool_device.transfer_per_layer(
+                    batch_operation.device_indices, flat_data, i
+                )
+            else:
+                self.mem_pool_host.load_page_per_layer(
+                    batch_operation.host_indices,
+                    batch_operation.device_indices,
+                    self.mem_pool_device,
+                    i,
+                )
+                self.load_stream.synchronize()
             self.layer_done_counter.increment()

         self.mem_pool_host.complete_io(batch_operation.host_indices)
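
The hunks above route KV-cache host writes and loads through either the original buffered/per-layer path (page_size == 1) or a new direct per-page path. A minimal, self-contained sketch of the thread-target dispatch pattern, with illustrative names that are not part of sglang:

import threading

class Controller:
    def __init__(self, page_size: int):
        self.page_size = page_size
        # page_size == 1 keeps the buffered writer; larger pages take the direct path.
        self.write_thread = threading.Thread(
            target=(self.write_buffered if page_size == 1 else self.write_direct),
            daemon=True,
        )

    def write_buffered(self):
        pass

    def write_direct(self):
        pass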
sglang/srt/managers/mm_utils.py

@@ -1,5 +1,5 @@
 """
-Multimodality utils
+Multi-modality utils
 """

 from abc import abstractmethod
@@ -9,11 +9,13 @@ import torch
 from torch import nn

 from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
     logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import print_warning_once
 from sglang.utils import logger


@@ -26,7 +28,7 @@ class MultiModalityDataPaddingPattern:

     @abstractmethod
     def pad_input_tokens(
-        self, input_ids: List[int], image_inputs: MultimodalInputs
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
         Pad the input ids sequence containing data tokens, and replace them with pad_values
@@ -49,13 +51,13 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         """
         This function will replace the data-tokens inbetween with pad_values accordingly
         """
-        pad_values = mm_inputs.pad_values
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]
         data_token_pairs = self.data_token_id_pairs
-        mm_inputs.image_offsets = []
+        mm_inputs.data_offsets = []
         if data_token_pairs is None:
             data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
         if data_token_pairs is None:
-            logger.warning(
+            print_warning_once(
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
@@ -77,10 +79,10 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)

             if input_ids[start_idx] in start_token_ids:
                 data_idx += 1
-                mm_inputs.image_offsets += [start_idx]
+                mm_inputs.data_offsets += [start_idx]

-            if data_idx >= len(mm_inputs.pad_values):
-                data_idx = len(mm_inputs.pad_values) - 1
+            if data_idx >= len(pad_values):
+                data_idx = len(pad_values) - 1

             num_tokens = end_idx - start_idx - 1
             pad_value = pad_values[data_idx]
@@ -94,68 +96,19 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         return padded_ids


-class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
-    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
-    which needs first to be expanded to multiple tokens, then replaced with their padding values
-
-    This strategy should be used when a single data token represents content that should
-    be expanded to multiple tokens during processing.
-    """
-
-    def __init__(
-        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
-    ) -> None:
-        self.num_data_token_calc_func = num_data_token_calc_func
-
-    def pad_input_tokens(
-        self, input_ids: List[int], mm_inputs: MultimodalInputs
-    ) -> List[int]:
-        """
-        This function will follow the procedure of:
-        1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
-        2. the padded data tokens will be replaced with their pad_values
-        """
-        image_grid_thws = mm_inputs.image_grid_thws
-        pad_values = mm_inputs.pad_values
-
-        image_indices = [
-            idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
-        ]
-
-        mm_inputs.image_offsets = []
-
-        input_ids_with_image = []
-        for image_cnt, _ in enumerate(image_grid_thws):
-            # print(f"image_cnt {image_cnt}")
-            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
-            if image_cnt == 0:
-                non_image_tokens = input_ids[: image_indices[image_cnt]]
-            else:
-                non_image_tokens = input_ids[
-                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
-                ]
-            input_ids_with_image.extend(non_image_tokens)
-            mm_inputs.image_offsets.append(len(input_ids_with_image))
-            pad_ids = pad_values * (
-                (num_image_tokens + len(pad_values)) // len(pad_values)
-            )
-            input_ids_with_image.extend(pad_ids[:num_image_tokens])
-        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
-
-        return input_ids_with_image
-
-
 class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
-    """In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""
+    """In this pattern, data tokens should be represented as repetitions of a single token
+    e.g. <image><image>....<image>, or <audio><audio>...<audio>
+    """

     def __init__(self, image_token_id: torch.Tensor) -> None:
         self.image_token_id = image_token_id

-    def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
+    def pad_input_tokens(self, input_ids: List[int], mm_inputs) -> List[int]:
         """
         This function will replace the data-tokens in between with pad_values accordingly
         """
-        pad_values = image_inputs.pad_values
+        pad_values = [item.pad_value for item in mm_inputs.mm_items]
         assert len(pad_values) != 0

         input_ids_tensor = torch.tensor(input_ids)
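
The padding patterns above now derive pad values from the per-item mm_items entries (item.pad_value) instead of a flat pad_values list. As a toy illustration of what the ImageTokens pattern does conceptually, repeated occurrences of a single placeholder token id are overwritten with a data item's pad value (the ids and values below are made up, not from this diff):

import torch

image_token_id = 32000
pad_value = 17
input_ids = torch.tensor([5, 32000, 32000, 32000, 9])
input_ids[input_ids == image_token_id] = pad_value
print(input_ids.tolist())  # [5, 17, 17, 17, 9]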
@@ -170,138 +123,227 @@ class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern
         return input_ids_tensor.tolist()


+def get_embedding_and_mask(
+    data_embedding_func: Callable[[List[MultimodalDataItem]], torch.Tensor],
+    embedding_items: List[MultimodalDataItem],
+    placeholder_tensor: torch.Tensor,
+    input_ids: torch.Tensor,
+):
+    """
+    Get the multimodal embedding and its mask from input_ids
+
+    """
+    # 1. Get the embedding
+    embedding = data_embedding_func(embedding_items)
+
+    # 2. Check the embedding
+    if embedding.dim() == 2:
+        num_mm_tokens_in_embedding = embedding.shape[0]
+    else:
+        num_mm_tokens_in_embedding = embedding.shape[0] * embedding.shape[1]
+
+    # the mask of multimodal tokens from input_ids
+    special_multimodal_mask = torch.isin(
+        input_ids,
+        placeholder_tensor,
+    ).unsqueeze(-1)
+
+    num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
+    if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
+        logger.warning(
+            f"Number of tokens in multimodal embedding does not match those in the input text."
+            f"Got {num_mm_tokens_in_input_ids} tokens in the text but {num_mm_tokens_in_embedding} "
+            "tokens from multimodal embeddings."
+        )
+        if num_mm_tokens_in_input_ids < num_mm_tokens_in_embedding:
+            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
+            # a fix may be cache the unfinished multimodal embedding for future reuse, determine the tokens to embed with
+            # extend_start_loc and extend_seq_lens
+            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
+            if chunked_prefill_size != -1:
+                logger.warning(
+                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked prefill"
+                )
+            # extract from the end: this is a compromise
+            if embedding.dim() == 2:
+                embedding = embedding[-num_mm_tokens_in_input_ids:, :]
+            else:
+                num_multimodal = num_mm_tokens_in_input_ids // embedding.shape[0]
+                embedding = embedding[-num_multimodal:, :]
+        else:
+            raise RuntimeError(
+                "Insufficient multimodal embedding length. This is an internal error"
+            )
+
+    return embedding, special_multimodal_mask
+
+
 def embed_mm_inputs(
-    mm_input: MultimodalInputs,
+    mm_inputs: MultimodalInputs,
     input_ids: torch.Tensor,
     input_embedding: nn.Embedding,
-    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
+    image_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
+    audio_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
     placeholder_token_ids: List[int] = None,
 ) -> Optional[torch.Tensor]:
     """
-    Calculate the image embeddings if necessary, then scatter the result with
-    the help of a boolean mask denoting the embed locations
+    Calculate the multimodal embeddings if necessary, then scatter the result with the help of a boolean mask denoting the embed locations

-    Returns:
-        final embedding: Optional[torch.Tensor]
+    Args:
+        placeholder_token_ids: denoting the token of multimodal data in input_ids.
+            If none, the pad_values of multimodal items are used
+
+    Returns:
+        final embedding: Optional[torch.Tensor]
     """
-    if mm_input is None:
+
+    if mm_inputs is None:
         return None

-    placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
+    # 1. Calculate the multimodal data which exists in input_ids, with the help of pad_values
+    # we assume that multimodal data are represented with its pad_values in input_ids
+    placeholder_token_ids = placeholder_token_ids or [
+        item.pad_value for item in mm_inputs.mm_items
+    ]

-    # boolean masking the special tokens
-    special_image_mask = torch.isin(
-        input_ids,
-        torch.tensor(placeholder_token_ids, device=input_ids.device),
-    ).unsqueeze(-1)
+    placeholder_tensor = torch.tensor(placeholder_token_ids, device=input_ids.device)

-    num_image_tokens_in_input_ids = special_image_mask.sum()
-    # print(f"{num_image_tokens_in_input_ids}")
-    # print(f"{input_ids}")
+    placeholder_masks = torch.isin(input_ids, placeholder_tensor)

-    # return
-    if num_image_tokens_in_input_ids == 0:
-        # unexpected
+    appearing_pad_values = torch.unique(
+        input_ids[placeholder_masks], return_counts=False
+    )
+
+    if appearing_pad_values.numel() == 0:
+        # all been prefixed
         inputs_embeds = input_embedding(input_ids)
     else:
-        # print(f"Getting image feature")
-        image_embedding = mm_data_embedding_func(mm_input)
-
-        # print(f"image_embedding: {image_embedding.shape}")
+        appearing_items = [
+            item
+            for item in mm_inputs.mm_items
+            if item.pad_value is not None and item.pad_value in appearing_pad_values
+        ]

-        if image_embedding.dim() == 2:
-            num_image_tokens_in_embedding = image_embedding.shape[0]
-        else:
-            num_image_tokens_in_embedding = (
-                image_embedding.shape[0] * image_embedding.shape[1]
+        using_all_items = False
+        if len(appearing_items) == 0:
+            # This happens mostly when arg placeholder_token_ids is passed
+            logger.warning_once(
+                "No multimodal data item's pad value exist in placeholder ids. Using all items"
             )
-        if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
-            num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
-            image_embedding = image_embedding[:num_image, :]
-            logger.warning(
-                f"Number of images does not match number of special image tokens in the input text. "
-                f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
-                "tokens from image embeddings."
+            using_all_items = True
+            appearing_items = mm_inputs.mm_items
+
+        embeddings, masks = [], []
+
+        # 2. Get multimodal embedding separately
+        # TODO: make this more generic
+        # Try get image embedding if any
+        if (
+            any(True for item in appearing_items if item.is_image())
+            and image_data_embedding_func
+        ):
+            items = [item for item in appearing_items if item.is_image()]
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=image_data_embedding_func,
+                embedding_items=items,
+                placeholder_tensor=(
+                    placeholder_tensor
+                    if using_all_items
+                    else torch.tensor(
+                        [item.pad_value for item in items],
+                        device=input_ids.device,
+                    )
+                ),
+                input_ids=input_ids,
             )
+            embeddings += [embedding]
+            masks += [mask]

-        # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
-        # a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
-        # extend_start_loc and extend_seq_lens
-        if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
-            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
-            if chunked_prefill_size != -1:
-                logger.warning(
-                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
+        # Try get audio embedding if any
+        if (
+            any(True for item in appearing_items if item.is_audio())
+            and audio_data_embedding_func
+        ):
+            items = [item for item in appearing_items if item.is_audio()]
+            embedding, mask = get_embedding_and_mask(
+                data_embedding_func=audio_data_embedding_func,
+                embedding_items=items,
+                placeholder_tensor=(
+                    placeholder_tensor
+                    if using_all_items
+                    else torch.tensor(
+                        [item.pad_value for item in items],
+                        device=input_ids.device,
                     )
+                ),
+                input_ids=input_ids,
+            )
+            embeddings += [embedding]
+            masks += [mask]

+        # 3. Get input embeddings
         vocab_size = input_embedding.num_embeddings
-        # Important: clamp after getting original image regions
-        # Clamp input ids. This is because the input_ids for the image tokens are
-        # filled with the hash values of the image for the prefix matching in the radix attention.
+        # Important: clamp after getting original multimodal regions
+        # Clamp input ids. This is because the input_ids for the multimodal tokens are
+        # filled with the hash values of the multimodal for the prefix matching in the radix attention.
         # There values are useless because their embeddings will be replaced by vision embeddings anyway.
         input_ids.clamp_(min=0, max=vocab_size - 1)
         inputs_embeds = input_embedding(input_ids)

-        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
-            inputs_embeds.device
-        )
-
-        inputs_embeds = inputs_embeds.masked_scatter(
-            special_image_mask,
-            image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
-    return inputs_embeds
-
-
-def embed_image_embedding(
-    inputs_embeds: torch.Tensor,
-    image_embedding: torch.Tensor,
-    image_bounds: torch.Tensor,
-) -> torch.Tensor:
-    """
-    scatter image_embedding into inputs_embeds according to image_bounds
-    """
-    if len(image_bounds) > 0:
-        image_indices = torch.stack(
-            [
-                torch.arange(start, end, dtype=torch.long)
-                for start, end in image_bounds.tolist()
-            ]
-        ).to(inputs_embeds.device)
-
-        inputs_embeds.scatter_(
-            0,
-            image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
-            image_embedding.view(-1, image_embedding.shape[-1]),
-        )
+        # 4. scatter embeddings into input embedding
+        for embedding, mask in zip(embeddings, masks):
+            mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+            inputs_embeds = inputs_embeds.masked_scatter(
+                mask,
+                embedding.to(inputs_embeds.device, inputs_embeds.dtype),
+            )
     return inputs_embeds


 def general_mm_embed_routine(
     input_ids: torch.Tensor,
     forward_batch: ForwardBatch,
-    embed_tokens: nn.Embedding,
-    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
+    language_model: nn.Module,
+    image_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
+    audio_data_embedding_func: Callable[
+        [List[MultimodalDataItem]], torch.Tensor
+    ] = None,
     placeholder_token_ids: List[int] = None,
-):
+    **kwargs,
+) -> torch.Tensor:
     """
-    a general wrapper function to get final input embeds from multimodal models
-    with a language model as causal model
+    A general wrapper function to get final input embeds from multimodal models with a language model as causal model

     Args:
         placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
+        image_data_embedding_func : the function returning the image embedding
+        audio_data_embedding_func : the function returning the image embedding
+
+    Returns:
+        inputs_embedding
+        forwarded hidden states

     """
+
+    assert hasattr(language_model, "get_input_embeddings")
+    embed_tokens = language_model.get_input_embeddings()
     if (
         not forward_batch.forward_mode.is_decode()
         and forward_batch.contains_mm_inputs()
     ):
-        image = forward_batch.merge_mm_inputs()
+        mm_input = forward_batch.merge_mm_inputs()
         inputs_embeds = embed_mm_inputs(
-            mm_input=image,
+            mm_inputs=mm_input,
             input_ids=input_ids,
             input_embedding=embed_tokens,
-            mm_data_embedding_func=mm_data_embedding_func,
+            image_data_embedding_func=image_data_embedding_func,
+            audio_data_embedding_func=audio_data_embedding_func,
             placeholder_token_ids=placeholder_token_ids,
         )
         # once used, mm_inputs is useless
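
The rewritten embed_mm_inputs builds a boolean placeholder mask with torch.isin and injects each modality's embedding with masked_scatter. A standalone sketch of that pattern with toy sizes and made-up values (not the sglang API):

import torch
from torch import nn

hidden = 8
pad_value = 7                                  # stands in for one mm item's pad_value
input_ids = torch.tensor([1, 2, 7, 7, 7, 3])   # three placeholder tokens
mm_embedding = torch.randn(3, hidden)          # one embedding row per placeholder token

embed_tokens = nn.Embedding(100, hidden)
inputs_embeds = embed_tokens(input_ids.clamp(min=0, max=99))

mask = torch.isin(input_ids, torch.tensor([pad_value])).unsqueeze(-1)
inputs_embeds = inputs_embeds.masked_scatter(
    mask.expand_as(inputs_embeds),
    mm_embedding.to(inputs_embeds.dtype),
)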
@@ -310,7 +352,13 @@ def general_mm_embed_routine(
     else:
         inputs_embeds = embed_tokens(input_ids)

-    return inputs_embeds
+    hidden_states = language_model(
+        input_ids=None,
+        forward_batch=forward_batch,
+        input_embeds=inputs_embeds,
+        **kwargs,
+    )
+    return hidden_states


 def get_multimodal_data_bounds(
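
With this change, general_mm_embed_routine takes the language model itself instead of an embedding module and returns the forwarded hidden states. A hypothetical caller could look like the sketch below; the class and attribute names are illustrative, not from this diff:

from torch import nn
from sglang.srt.managers.mm_utils import general_mm_embed_routine

class MyVLM(nn.Module):  # hypothetical multimodal wrapper model
    def forward(self, input_ids, positions, forward_batch):
        # language_model must expose get_input_embeddings(); extra kwargs such as
        # positions are forwarded to language_model(...) by the routine.
        return general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            image_data_embedding_func=self.get_image_feature,
            positions=positions,
        )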
@@ -322,15 +370,13 @@
     Returns:
         [bounds_count, 2]
     """
-    # All the images in the batch should share the same special image
-    # bound token ids.
+    # All the multimodal data in the batch should share the same special bound token ids.
     start_tokens = [s for s, _e in token_pairs]
     end_tokens = [e for _s, e in token_pairs]

     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)

-    # print(input_ids)
     start_cond = torch.isin(
         input_ids, torch.tensor(start_tokens, device=input_ids.device)
     )
@@ -339,7 +385,7 @@
     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)

-    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
+    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
     if len(data_start_tokens) != len(data_end_tokens):
         if (
             len(data_start_tokens) + 1 == len(data_end_tokens)
@@ -352,14 +398,14 @@
                 data_start_tokens,
             ]
         )
-    valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
+    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))

-    if valid_image_nums == 0:
+    if valid_mm_data_nums == 0:
         return torch.zeros((0, 2), device=input_ids.device)

     # Filter out pairs where start_token >= end_token
     valid_pairs = []
-    for i in range(valid_image_nums):
+    for i in range(valid_mm_data_nums):
         start_token = data_start_tokens[i]
         end_token = data_end_tokens[i]
         if start_token < end_token:
sglang/srt/managers/multimodal_processor.py

@@ -64,5 +64,3 @@ def get_mm_processor(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
     )
-
-    self.image_proce