sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,373 @@
1
+ """
2
+ Multimodality utils
3
+ """
4
+
5
+ from abc import abstractmethod
6
+ from typing import Callable, List, Optional, Tuple
7
+
8
+ import torch
9
+ from torch import nn
10
+
11
+ from sglang.srt.managers.schedule_batch import (
12
+ MultimodalInputs,
13
+ global_server_args_dict,
14
+ logger,
15
+ )
16
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
17
+ from sglang.utils import logger
18
+
19
+
20
class MultiModalityDataPaddingPattern:
    """
    Common interface for strategies that pad multimodal data tokens.

    Data tokens (such as image tokens) need dedicated handling when the prompt
    is padded so that the sequence stays compatible with the model; each
    subclass provides one strategy through ``pad_input_tokens``.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` marker below is not enforced on instantiation.
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Return ``input_ids`` with its data tokens replaced by pad values.

        The base implementation does nothing and returns ``None``.
        """
35
+
36
+
37
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens are enclosed by special token pairs
    (e.g. ``<image>...</image>``), supplied as ``data_token_pairs``.

    This strategy should be applied when data content is marked by start/end
    token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        # (start_token_id, end_token_id) pairs; when None, the pair is read from
        # mm_inputs.im_start_id / mm_inputs.im_end_id in pad_input_tokens().
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Replace the tokens strictly between each start/end token pair with the
        pad value of the corresponding data item.

        The k-th pair is filled with ``mm_inputs.pad_values[k]``; if there are
        more pairs than pad values, the last pad value is reused.

        Side effect: ``mm_inputs.image_offsets`` is reset and then filled with
        the index of every start token that was processed.

        Returns ``input_ids`` unchanged when no token pairs are available, when
        the number of start and end tokens differs, or when start/end tokens do
        not strictly alternate (start < end < next start < ...).
        """
        pad_values = mm_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
        mm_inputs.image_offsets = []
        if data_token_pairs is None:
            im_start_id = mm_inputs.im_start_id
            im_end_id = mm_inputs.im_end_id
            if im_start_id is not None and im_end_id is not None:
                # Must be a list of (start, end) tuples: a flat [start, end]
                # list cannot be unpacked into pairs below.
                data_token_pairs = [(im_start_id, im_end_id)]
        if data_token_pairs is None:
            logger.warning(
                "No data_token_pairs provided, RadixAttention might be influenced."
            )
            return input_ids

        start_token_ids = [s for s, _e in data_token_pairs]
        end_token_ids = [e for _s, e in data_token_pairs]

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_token_ids]

        if len(start_indices) != len(end_indices):
            return input_ids

        # Starts and ends are matched by order, so positions must strictly
        # alternate; otherwise the slicing below would duplicate or drop tokens.
        boundaries = [idx for pair in zip(start_indices, end_indices) for idx in pair]
        if any(a >= b for a, b in zip(boundaries, boundaries[1:])):
            logger.warning(
                "Unbalanced data token pairs in input_ids; skipping data token padding."
            )
            return input_ids

        padded_ids = []
        last_idx = 0
        for data_idx, (start_idx, end_idx) in enumerate(
            zip(start_indices, end_indices)
        ):
            # Keep everything up to and including the start token.
            padded_ids.extend(input_ids[last_idx : start_idx + 1])
            mm_inputs.image_offsets.append(start_idx)

            pad_value = pad_values[min(data_idx, len(pad_values) - 1)]
            padded_ids.extend([pad_value] * (end_idx - start_idx - 1))

            # The end token itself is copied by the next slice (or the tail).
            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids), "Length validation fails"
        return padded_ids
95
+
96
+
97
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
    """In this pattern, data is represented with a special token_id
    (``mm_inputs.im_token_id``), which first needs to be expanded to multiple
    tokens and then replaced with their padding values.

    This strategy should be used when a single data token represents content
    that should be expanded to multiple tokens during processing.
    """

    def __init__(
        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
    ) -> None:
        # Maps one entry of mm_inputs.image_grid_thws to the number of tokens
        # the corresponding data item expands to.
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """
        Expand each data token into its pad values:

        1. each occurrence of ``mm_inputs.im_token_id`` is expanded to
           ``num_data_token_calc_func(grid_thw)`` tokens;
        2. the expanded tokens are filled by cycling ``mm_inputs.pad_values``.

        The i-th data token is paired with the i-th entry of
        ``mm_inputs.image_grid_thws``. If the counts differ, a warning is
        logged and only the matched pairs are expanded; every other token
        (including unmatched data tokens) is kept as-is.

        Side effect: ``mm_inputs.image_offsets`` is reset and filled with the
        position of each expanded run in the returned list.
        """
        image_grid_thws = mm_inputs.image_grid_thws
        pad_values = mm_inputs.pad_values

        image_indices = [
            idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
        ]
        if len(image_indices) != len(image_grid_thws):
            logger.warning(
                f"Found {len(image_indices)} data tokens in input_ids but "
                f"{len(image_grid_thws)} entries in image_grid_thws; "
                "only the matched ones are expanded."
            )

        mm_inputs.image_offsets = []

        input_ids_with_image = []
        # Start of the text run that follows the previously expanded data token.
        text_start = 0
        for image_idx, grid_thw in zip(image_indices, image_grid_thws):
            num_image_tokens = self.num_data_token_calc_func(grid_thw)
            input_ids_with_image.extend(input_ids[text_start:image_idx])
            mm_inputs.image_offsets.append(len(input_ids_with_image))
            # Repeat pad_values enough times to cover num_image_tokens, then cut.
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
            input_ids_with_image.extend(pad_ids[:num_image_tokens])
            text_start = image_idx + 1
        # Everything after the last expanded data token (all of input_ids when
        # nothing was expanded).
        input_ids_with_image.extend(input_ids[text_start:])

        return input_ids_with_image
146
+
147
+
148
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
    """Pattern where data is spelled out as repeated image tokens
    (e.g. ``<image><image>....<image>``)."""

    def __init__(self, image_token_id: torch.Tensor) -> None:
        # Id(s) tested with torch.isin to locate image tokens.
        self.image_token_id = image_token_id

    def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
        """
        Overwrite every image token with pad values, cycling through
        ``image_inputs.pad_values`` in order; other tokens are untouched.
        """
        pad_values = image_inputs.pad_values
        assert len(pad_values) != 0

        ids = torch.tensor(input_ids)
        is_image = torch.isin(ids, self.image_token_id)

        n_image = is_image.sum().item()
        repeats = n_image // len(pad_values) + 1
        fill = torch.tensor(pad_values).repeat(repeats)[:n_image]

        ids[is_image] = fill
        return ids.tolist()
171
+
172
+
173
def embed_mm_inputs(
    mm_input: MultimodalInputs,
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
    placeholder_token_ids: List[int] = None,
) -> Optional[torch.Tensor]:
    """
    Embed ``input_ids`` and overwrite the placeholder positions with the
    embeddings produced by ``mm_data_embedding_func``.

    Args:
        mm_input: multimodal inputs of the batch; ``None`` short-circuits.
        input_ids: token ids. Positions whose id is in ``placeholder_token_ids``
            receive the multimodal embeddings. NOTE: when placeholders are
            present, this tensor is clamped *in place* to ``[0, vocab_size - 1]``.
        input_embedding: text embedding layer.
        mm_data_embedding_func: computes the multimodal embeddings; all leading
            dims of its output are flattened so that each row is one token.
        placeholder_token_ids: ids marking data positions; defaults to
            ``mm_input.pad_values``.

    Returns:
        The final input embeddings, or ``None`` if ``mm_input`` is ``None``.
    """
    if mm_input is None:
        return None

    placeholder_token_ids = placeholder_token_ids or mm_input.pad_values

    # Boolean mask of the positions holding placeholder (data) tokens.
    special_image_mask = torch.isin(
        input_ids,
        torch.tensor(placeholder_token_ids, device=input_ids.device),
    ).unsqueeze(-1)

    num_image_tokens_in_input_ids = int(special_image_mask.sum().item())
    if num_image_tokens_in_input_ids == 0:
        # unexpected: no placeholder tokens, fall back to plain text embedding
        return input_embedding(input_ids)

    image_embedding = mm_data_embedding_func(mm_input)
    # Flatten e.g. (num_images, tokens_per_image, hidden) to (num_tokens, hidden).
    # Row-major order is preserved, which is the order masked_scatter consumes.
    image_embedding = image_embedding.reshape(-1, image_embedding.shape[-1])
    num_image_tokens_in_embedding = image_embedding.shape[0]

    if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
        logger.warning(
            f"Number of images does not match number of special image tokens in the input text. "
            f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
            "tokens from image embeddings."
        )
        # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
        # a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
        # extend_start_loc and extend_seq_lens
        if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
            if chunked_prefill_size != -1:
                logger.warning(
                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
                )
        # Keep exactly one embedding row per placeholder position (no-op when
        # there are fewer rows than placeholders).
        image_embedding = image_embedding[:num_image_tokens_in_input_ids]

    vocab_size = input_embedding.num_embeddings
    # Important: clamp only after the placeholder mask has been computed.
    # The placeholder positions hold hash values of the data (used for prefix
    # matching in the radix cache) that may exceed the vocab; their embeddings
    # are overwritten below anyway.
    input_ids.clamp_(min=0, max=vocab_size - 1)
    inputs_embeds = input_embedding(input_ids)

    special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
        inputs_embeds.device
    )
    return inputs_embeds.masked_scatter(
        special_image_mask,
        image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
    )
254
+
255
+
256
def embed_image_embedding(
    inputs_embeds: torch.Tensor,
    image_embedding: torch.Tensor,
    image_bounds: torch.Tensor,
) -> torch.Tensor:
    """
    Scatter ``image_embedding`` rows into ``inputs_embeds`` in place.

    Each row ``(start, end)`` of ``image_bounds`` selects the positions
    ``start, ..., end - 1`` (end exclusive) along dim 0 of ``inputs_embeds``.
    Spans may have different lengths. The selected positions of all spans,
    in order, are filled with the rows of ``image_embedding`` flattened to
    ``(-1, hidden)``. ``image_embedding`` is cast to ``inputs_embeds``'s dtype.

    NOTE(review): ``get_multimodal_data_bounds`` returns ``(start + 1, end - 1)``
    per span; confirm that callers pass bounds with an exclusive end here.

    Returns:
        ``inputs_embeds`` (modified in place).
    """
    if len(image_bounds) > 0:
        # cat (not stack) so spans of different lengths are supported; for
        # equal lengths the flattened order is identical.
        image_indices = torch.cat(
            [
                torch.arange(start, end, dtype=torch.long)
                for start, end in image_bounds.tolist()
            ]
        ).to(inputs_embeds.device)

        hidden_size = inputs_embeds.shape[-1]
        inputs_embeds.scatter_(
            0,
            image_indices.view(-1, 1).repeat(1, hidden_size),
            # scatter_ requires src and self to share a dtype
            image_embedding.view(-1, image_embedding.shape[-1]).to(
                inputs_embeds.dtype
            ),
        )
    return inputs_embeds
278
+
279
+
280
def general_mm_embed_routine(
    input_ids: torch.Tensor,
    forward_batch: ForwardBatch,
    embed_tokens: nn.Embedding,
    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
    placeholder_token_ids: List[int] = None,
):
    """
    Produce the input embeddings for a multimodal model whose language part
    is a causal LM.

    Decode batches, and batches without multimodal inputs, are embedded with
    ``embed_tokens`` only. Otherwise the batch's multimodal inputs are merged,
    embedded via ``embed_mm_inputs``, and then cleared from ``forward_batch``.

    Args:
        placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
    """
    if forward_batch.forward_mode.is_decode() or not forward_batch.contains_mm_inputs():
        return embed_tokens(input_ids)

    merged_inputs = forward_batch.merge_mm_inputs()
    embeddings = embed_mm_inputs(
        mm_input=merged_inputs,
        input_ids=input_ids,
        input_embedding=embed_tokens,
        mm_data_embedding_func=mm_data_embedding_func,
        placeholder_token_ids=placeholder_token_ids,
    )
    # The mm inputs have been consumed above; drop them defensively.
    forward_batch.mm_inputs = None
    return embeddings
314
+
315
+
316
def get_multimodal_data_bounds(
    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
) -> torch.Tensor:
    """
    Locate the spans of multimodal data (images, video, audio, etc.) in
    ``input_ids``.

    A span is the run of tokens strictly between a start token and the end
    token matched to it by order; start/end ids come from ``token_pairs``.
    Pairs whose start is not before their end are dropped.

    If the start token of the first span was cached as a prefix (one fewer
    start than end tokens, ``input_ids[0]`` is a pad value and the first end
    precedes the first start), that span is taken to begin at index 0.

    Returns:
        A long tensor of shape ``[bounds_count, 2]`` on ``input_ids.device``.
        Each row is ``(start + 1, end - 1)``: the indices of the first and
        last data token of a span.
        NOTE(review): ``embed_image_embedding`` treats the second column as an
        exclusive end; confirm which convention callers expect.
    """
    # All the images in the batch should share the same special image
    # bound token ids.
    start_tokens = [s for s, _e in token_pairs]
    end_tokens = [e for _s, e in token_pairs]

    assert all(isinstance(t, int) for t in start_tokens)
    assert all(isinstance(t, int) for t in end_tokens)

    device = input_ids.device
    start_cond = torch.isin(input_ids, torch.tensor(start_tokens, device=device))
    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=device))

    (data_start_tokens,) = torch.where(start_cond)
    (data_end_tokens,) = torch.where(end_cond)

    # The start token can already be cached as a prefix, but its data still
    # has to be embedded. Guard the empty-start case before indexing [0].
    if (
        len(data_start_tokens) + 1 == len(data_end_tokens)
        and int(input_ids[0]) in pad_values
        and (len(data_start_tokens) == 0 or data_end_tokens[0] < data_start_tokens[0])
    ):
        # Virtual start token at -1, so the span begins at index 0 (which holds
        # a pad value, i.e. data, as checked above).
        data_start_tokens = torch.cat(
            [
                torch.tensor([-1], dtype=data_start_tokens.dtype, device=data_start_tokens.device),
                data_start_tokens,
            ]
        )

    # zip truncates to min(len(starts), len(ends)) valid pairs.
    valid_pairs = [
        (start + 1, end - 1)
        for start, end in zip(data_start_tokens.tolist(), data_end_tokens.tolist())
        if start < end
    ]

    if not valid_pairs:
        return torch.zeros((0, 2), dtype=torch.long, device=device)
    return torch.tensor(valid_pairs, dtype=torch.long, device=device)
@@ -0,0 +1,68 @@
1
+ # TODO: also move pad_input_ids into this module
2
+ import importlib
3
+ import inspect
4
+ import logging
5
+ import pkgutil
6
+ from functools import lru_cache
7
+
8
+ from transformers import PROCESSOR_MAPPING
9
+
10
+ from sglang.srt.managers.multimodal_processors.base_processor import (
11
+ BaseMultimodalProcessor,
12
+ )
13
+ from sglang.srt.server_args import ServerArgs
14
+
15
logger = logging.getLogger(__name__)

# Registry filled by import_processors(): maps each entry of a processor
# class's `models` attribute to that processor class. get_mm_processor()
# reads `.__name__` from the keys, so they are expected to be classes.
# NOTE: this rebinds the name `PROCESSOR_MAPPING` imported from transformers
# above; from here on the name refers to this local dict only.
PROCESSOR_MAPPING = {}
18
+
19
+
20
class DummyMultimodalProcessor(BaseMultimodalProcessor):
    """No-op multimodal processor whose processing always yields ``None``."""

    def __init__(self):
        # Does not call BaseMultimodalProcessor.__init__, so no base-class
        # state is set up.
        return

    async def process_mm_data_async(self, *args, **kwargs):
        """Ignore every argument and resolve to ``None``."""
        result = None
        return result
26
+
27
+
28
def get_dummy_processor():
    """Return a fresh ``DummyMultimodalProcessor`` instance."""
    dummy = DummyMultimodalProcessor()
    return dummy
30
+
31
+
32
@lru_cache()
def import_processors():
    """
    Import every non-package module of
    ``sglang.srt.managers.multimodal_processors`` and register its processors.

    For each class defined in a module (not merely imported into it) that
    subclasses ``BaseMultimodalProcessor``, every entry of the class's
    ``models`` attribute is mapped to the class in ``PROCESSOR_MAPPING``.
    Modules that fail to import are logged and skipped. Thanks to
    ``lru_cache``, calls after the first successful one return immediately.
    """
    package_name = "sglang.srt.managers.multimodal_processors"
    package = importlib.import_module(package_name)
    for module_info in pkgutil.iter_modules(package.__path__, package_name + "."):
        if module_info.ispkg:
            continue
        module_name = module_info.name
        try:
            module = importlib.import_module(module_name)
        except Exception as e:
            logger.warning(f"Ignore import error when loading {module_name}: " f"{e}")
            continue
        for _member_name, member in inspect.getmembers(module, inspect.isclass):
            # Only classes defined in this very module.
            if member.__module__ != module.__name__:
                continue
            if not issubclass(member, BaseMultimodalProcessor):
                continue
            assert hasattr(member, "models")
            for arch in getattr(member, "models"):
                PROCESSOR_MAPPING[arch] = member
55
+
56
+
57
def get_mm_processor(
    hf_config, server_args: ServerArgs, processor
) -> BaseMultimodalProcessor:
    """
    Build the multimodal processor registered for ``hf_config``.

    The first ``PROCESSOR_MAPPING`` entry whose key's ``__name__`` appears in
    ``hf_config.architectures`` wins; its processor class is constructed with
    ``(hf_config, server_args, processor)``.

    Raises:
        ValueError: if no registered processor matches, including when
            ``hf_config.architectures`` is missing or ``None``.
    """
    architectures = getattr(hf_config, "architectures", None)
    for model_cls, processor_cls in PROCESSOR_MAPPING.items():
        # Configs may leave `architectures` unset (None); `in None` would raise
        # a TypeError instead of the ValueError below.
        if architectures and model_cls.__name__ in architectures:
            return processor_cls(hf_config, server_args, processor)
    raise ValueError(
        f"No processor registered for architecture: {architectures}.\n"
        f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
    )
67
+
68
+ self.image_proce