sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union

 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -45,6 +45,8 @@ class GenerateReqInput:
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
+    # The audio input. Like image data, it can be a file name, a url, or base64 encoded string.
+    audio_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
@@ -103,6 +105,8 @@ class GenerateReqInput:
                 self.batch_size = len(self.text)
             self.input_embeds = None
         elif self.input_ids is not None:
+            if len(self.input_ids) == 0:
+                raise ValueError("input_ids cannot be empty.")
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
@@ -165,6 +169,13 @@ class GenerateReqInput:
         elif isinstance(self.image_data, list):
             pass

+        if self.audio_data is None:
+            self.audio_data = [None] * num
+        elif not isinstance(self.audio_data, list):
+            self.audio_data = [self.audio_data] * num
+        elif isinstance(self.audio_data, list):
+            pass
+
         if self.sampling_params is None:
             self.sampling_params = [{}] * num
         elif not isinstance(self.sampling_params, list):
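Note: the audio_data normalization above mirrors the existing image_data handling. As a minimal standalone sketch of the rule (normalize_audio_data is an illustrative helper, not an sglang API):

# Sketch of the batch-normalization rule added above; not sglang code.
from typing import List, Optional, Union


def normalize_audio_data(
    audio_data: Optional[Union[List[str], str]], num: int
) -> List[Optional[str]]:
    if audio_data is None:
        # no audio: one None per request in the batch
        return [None] * num
    if not isinstance(audio_data, list):
        # a single clip is shared by every request in the batch
        return [audio_data] * num
    # an explicit per-request list passes through unchanged
    return audio_data


assert normalize_audio_data(None, 2) == [None, None]
assert normalize_audio_data("clip.wav", 3) == ["clip.wav"] * 3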
@@ -229,6 +240,7 @@ class GenerateReqInput:
             text=self.text[i] if self.text is not None else None,
             input_ids=self.input_ids[i] if self.input_ids is not None else None,
             image_data=self.image_data[i],
+            audio_data=self.audio_data[i],
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
             return_logprob=self.return_logprob[i],
@@ -257,8 +269,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The image inputs
-    image_inputs: dict
+    # The multimodal inputs
+    mm_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -538,7 +550,8 @@ class UpdateWeightsFromDistributedReqOutput:

 @dataclass
 class UpdateWeightsFromTensorReqInput:
-    serialized_named_tensors: bytes  # indeed Dict[str, torch.Tensor]
+    # List containing one serialized Dict[str, torch.Tensor] per TP worker
+    serialized_named_tensors: List[bytes]
     load_format: Optional[str]
     flush_cache: bool

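Note: serialized_named_tensors is now a list carrying one payload per TP worker. A hedged sketch of constructing such a payload, using pickle as a stand-in for sglang's internal serializer and tp_size=2 as an example value:

import pickle

import torch

tp_size = 2
named_tensors = {"lm_head.weight": torch.zeros(8, 8)}

# One serialized Dict[str, torch.Tensor] per TP worker, matching List[bytes]:
serialized_named_tensors = [pickle.dumps(named_tensors) for _ in range(tp_size)]
assert len(serialized_named_tensors) == tp_size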
@@ -637,7 +650,7 @@ class ProfileReqInput:
     # If it is set, profiling is automatically stopped after this step, and
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
-    activities: Optional[List[str]] = None
+    activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None


 class ProfileReqType(Enum):
@@ -645,12 +658,25 @@ class ProfileReqType(Enum):
     STOP_PROFILE = 2


+class ExpertDistributionReq(Enum):
+    START_RECORD = 1
+    STOP_RECORD = 2
+    DUMP_RECORD = 3
+
+
+@dataclass
+class ExpertDistributionReqOutput:
+    pass
+
+
 @dataclass
 class ProfileReq:
     type: ProfileReqType
     output_dir: Optional[str] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
+    with_stack: Optional[bool] = None
+    record_shapes: Optional[bool] = None


 @dataclass
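Note: the new with_stack and record_shapes fields match keyword arguments of torch.profiler.profile, which is presumably where they are forwarded when a profile request is handled. A torch-only sketch of what those knobs control:

# What ProfileReq.with_stack / record_shapes correspond to in PyTorch's
# profiler API (assumption: sglang passes them through to torch.profiler).
import torch
from torch.profiler import ProfilerActivity, profile

with profile(
    activities=[ProfilerActivity.CPU],  # "GPU" maps to ProfilerActivity.CUDA
    with_stack=True,      # record source stacks for each op
    record_shapes=True,   # record input shapes for each op
) as prof:
    torch.ones(64, 64) @ torch.ones(64, 64)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))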
@@ -723,3 +749,15 @@ class SeparateReasoningReqInput:
 class VertexGenerateReqInput:
     instances: List[dict]
     parameters: Optional[dict] = None
+
+
+@dataclass
+class RpcReqInput:
+    method: str
+    parameters: Optional[Dict] = None
+
+
+@dataclass
+class RpcReqOutput:
+    success: bool
+    message: str
sglang/srt/managers/mm_utils.py
@@ -0,0 +1,373 @@
+"""
+Multimodality utils
+"""
+
+from abc import abstractmethod
+from typing import Callable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.managers.schedule_batch import (
+    MultimodalInputs,
+    global_server_args_dict,
+    logger,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.utils import logger
+
+
+class MultiModalityDataPaddingPattern:
+    """
+    Data tokens (like image tokens) often need special handling during padding
+    to maintain model compatibility. This class provides the interface for
+    implementing different padding strategies for data tokens
+    """
+
+    @abstractmethod
+    def pad_input_tokens(
+        self, input_ids: List[int], image_inputs: MultimodalInputs
+    ) -> List[int]:
+        """
+        Pad the input ids sequence containing data tokens, and replace them with pad_values
+        """
+        pass
+
+
+class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
+    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
+
+    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
+    """
+
+    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
+        self.data_token_id_pairs = data_token_pairs
+
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
+        """
+        This function will replace the data-tokens inbetween with pad_values accordingly
+        """
+        pad_values = mm_inputs.pad_values
+        data_token_pairs = self.data_token_id_pairs
+        mm_inputs.image_offsets = []
+        if data_token_pairs is None:
+            data_token_pairs = [mm_inputs.im_start_id, mm_inputs.im_end_id]
+        if data_token_pairs is None:
+            logger.warning(
+                "No data_token_pairs provided, RadixAttention might be influenced."
+            )
+            return input_ids
+        start_token_ids = [s for s, _e in data_token_pairs]
+        end_tokens_ids = [e for _s, e in data_token_pairs]
+
+        padded_ids = []
+        last_idx = 0
+        data_idx = -1
+
+        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
+        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]
+
+        if len(start_indices) != len(end_indices):
+            return input_ids
+
+        for start_idx, end_idx in zip(start_indices, end_indices):
+            padded_ids.extend(input_ids[last_idx : start_idx + 1])
+
+            if input_ids[start_idx] in start_token_ids:
+                data_idx += 1
+                mm_inputs.image_offsets += [start_idx]
+
+            if data_idx >= len(mm_inputs.pad_values):
+                data_idx = len(mm_inputs.pad_values) - 1
+
+            num_tokens = end_idx - start_idx - 1
+            pad_value = pad_values[data_idx]
+            padded_ids.extend([pad_value] * num_tokens)
+
+            last_idx = end_idx
+
+        padded_ids.extend(input_ids[last_idx:])
+
+        assert len(input_ids) == len(padded_ids), "Length validation fails"
+        return padded_ids
+
+
+class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
+    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
+    which needs first to be expanded to multiple tokens, then replaced with their padding values
+
+    This strategy should be used when a single data token represents content that should
+    be expanded to multiple tokens during processing.
+    """
+
+    def __init__(
+        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
+    ) -> None:
+        self.num_data_token_calc_func = num_data_token_calc_func
+
+    def pad_input_tokens(
+        self, input_ids: List[int], mm_inputs: MultimodalInputs
+    ) -> List[int]:
+        """
+        This function will follow the procedure of:
+        1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
+        2. the padded data tokens will be replaced with their pad_values
+        """
+        image_grid_thws = mm_inputs.image_grid_thws
+        pad_values = mm_inputs.pad_values
+
+        image_indices = [
+            idx for idx, token in enumerate(input_ids) if token == mm_inputs.im_token_id
+        ]
+
+        mm_inputs.image_offsets = []
+
+        input_ids_with_image = []
+        for image_cnt, _ in enumerate(image_grid_thws):
+            # print(f"image_cnt {image_cnt}")
+            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
+            if image_cnt == 0:
+                non_image_tokens = input_ids[: image_indices[image_cnt]]
+            else:
+                non_image_tokens = input_ids[
+                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
+                ]
+            input_ids_with_image.extend(non_image_tokens)
+            mm_inputs.image_offsets.append(len(input_ids_with_image))
+            pad_ids = pad_values * (
+                (num_image_tokens + len(pad_values)) // len(pad_values)
+            )
+            input_ids_with_image.extend(pad_ids[:num_image_tokens])
+        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
+
+        return input_ids_with_image
+
+
+class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
+    """In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""
+
+    def __init__(self, image_token_id: torch.Tensor) -> None:
+        self.image_token_id = image_token_id
+
+    def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
+        """
+        This function will replace the data-tokens in between with pad_values accordingly
+        """
+        pad_values = image_inputs.pad_values
+        assert len(pad_values) != 0
+
+        input_ids_tensor = torch.tensor(input_ids)
+        mask = torch.isin(input_ids_tensor, self.image_token_id)
+
+        num_image_tokens = mask.sum().item()
+        repeated_pad_values = torch.tensor(pad_values).repeat(
+            num_image_tokens // len(pad_values) + 1
+        )[:num_image_tokens]
+
+        input_ids_tensor[mask] = repeated_pad_values
+        return input_ids_tensor.tolist()
+
+
+def embed_mm_inputs(
+    mm_input: MultimodalInputs,
+    input_ids: torch.Tensor,
+    input_embedding: nn.Embedding,
+    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
+    placeholder_token_ids: List[int] = None,
+) -> Optional[torch.Tensor]:
+    """
+    Calculate the image embeddings if necessary, then scatter the result with
+    the help of a boolean mask denoting the embed locations
+
+    Returns:
+        final embedding: Optional[torch.Tensor]
+    """
+    if mm_input is None:
+        return None
+
+    placeholder_token_ids = placeholder_token_ids or mm_input.pad_values
+
+    # boolean masking the special tokens
+    special_image_mask = torch.isin(
+        input_ids,
+        torch.tensor(placeholder_token_ids, device=input_ids.device),
+    ).unsqueeze(-1)
+
+    num_image_tokens_in_input_ids = special_image_mask.sum()
+    # print(f"{num_image_tokens_in_input_ids}")
+    # print(f"{input_ids}")
+
+    # return
+    if num_image_tokens_in_input_ids == 0:
+        # unexpected
+        inputs_embeds = input_embedding(input_ids)
+    else:
+        # print(f"Getting image feature")
+        image_embedding = mm_data_embedding_func(mm_input)
+
+        # print(f"image_embedding: {image_embedding.shape}")
+
+        if image_embedding.dim() == 2:
+            num_image_tokens_in_embedding = image_embedding.shape[0]
+        else:
+            num_image_tokens_in_embedding = (
+                image_embedding.shape[0] * image_embedding.shape[1]
+            )
+        if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
+            num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
+            image_embedding = image_embedding[:num_image, :]
+            logger.warning(
+                f"Number of images does not match number of special image tokens in the input text. "
+                f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
+                "tokens from image embeddings."
+            )
+
+        # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
+        # a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
+        # extend_start_loc and extend_seq_lens
+        if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
+            chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
+            if chunked_prefill_size != -1:
+                logger.warning(
+                    "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
+                )
+
+        vocab_size = input_embedding.num_embeddings
+        # Important: clamp after getting original image regions
+        # Clamp input ids. This is because the input_ids for the image tokens are
+        # filled with the hash values of the image for the prefix matching in the radix attention.
+        # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+        input_ids.clamp_(min=0, max=vocab_size - 1)
+        inputs_embeds = input_embedding(input_ids)
+
+        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+            inputs_embeds.device
+        )
+
+        inputs_embeds = inputs_embeds.masked_scatter(
+            special_image_mask,
+            image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
+        )
+    return inputs_embeds
+
+
+def embed_image_embedding(
+    inputs_embeds: torch.Tensor,
+    image_embedding: torch.Tensor,
+    image_bounds: torch.Tensor,
+) -> torch.Tensor:
+    """
+    scatter image_embedding into inputs_embeds according to image_bounds
+    """
+    if len(image_bounds) > 0:
+        image_indices = torch.stack(
+            [
+                torch.arange(start, end, dtype=torch.long)
+                for start, end in image_bounds.tolist()
+            ]
+        ).to(inputs_embeds.device)
+
+        inputs_embeds.scatter_(
+            0,
+            image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
+            image_embedding.view(-1, image_embedding.shape[-1]),
+        )
+    return inputs_embeds
+
+
+def general_mm_embed_routine(
+    input_ids: torch.Tensor,
+    forward_batch: ForwardBatch,
+    embed_tokens: nn.Embedding,
+    mm_data_embedding_func: Callable[[MultimodalInputs], torch.Tensor],
+    placeholder_token_ids: List[int] = None,
+):
+    """
+    a general wrapper function to get final input embeds from multimodal models
+    with a language model as causal model
+
+    Args:
+        placeholder_token_ids (List[int]): the ids of mm data placeholder tokens
+
+    """
+    if (
+        not forward_batch.forward_mode.is_decode()
+        and forward_batch.contains_mm_inputs()
+    ):
+        image = forward_batch.merge_mm_inputs()
+        inputs_embeds = embed_mm_inputs(
+            mm_input=image,
+            input_ids=input_ids,
+            input_embedding=embed_tokens,
+            mm_data_embedding_func=mm_data_embedding_func,
+            placeholder_token_ids=placeholder_token_ids,
+        )
+        # once used, mm_inputs is useless
+        # just being defensive here
+        forward_batch.mm_inputs = None
+    else:
+        inputs_embeds = embed_tokens(input_ids)
+
+    return inputs_embeds
+
+
+def get_multimodal_data_bounds(
+    input_ids: torch.Tensor, pad_values: List[int], token_pairs: List[Tuple[int, int]]
+) -> torch.Tensor:
+    """
+    Returns a tensor indicating the bounds of multimodal data (images, video, audio, etc.)
+
+    Returns:
+        [bounds_count, 2]
+    """
+    # All the images in the batch should share the same special image
+    # bound token ids.
+    start_tokens = [s for s, _e in token_pairs]
+    end_tokens = [e for _s, e in token_pairs]
+
+    assert all(isinstance(t, int) for t in start_tokens)
+    assert all(isinstance(t, int) for t in end_tokens)
+
+    # print(input_ids)
+    start_cond = torch.isin(
+        input_ids, torch.tensor(start_tokens, device=input_ids.device)
+    )
+    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))

+    (data_start_tokens,) = torch.where(start_cond)
+    (data_end_tokens,) = torch.where(end_cond)
+
+    # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images
+    if len(data_start_tokens) != len(data_end_tokens):
+        if (
+            len(data_start_tokens) + 1 == len(data_end_tokens)
+            and input_ids[0] in pad_values
+            and data_end_tokens[0] < data_start_tokens[0]
+        ):
+            data_start_tokens = torch.cat(
+                [
+                    torch.tensor([0], device=data_start_tokens.device),
+                    data_start_tokens,
+                ]
+            )
+    valid_image_nums = min(len(data_start_tokens), len(data_end_tokens))
+
+    if valid_image_nums == 0:
+        return torch.zeros((0, 2), device=input_ids.device)
+
+    # Filter out pairs where start_token >= end_token
+    valid_pairs = []
+    for i in range(valid_image_nums):
+        start_token = data_start_tokens[i]
+        end_token = data_end_tokens[i]
+        if start_token < end_token:
+            valid_pairs.append((start_token + 1, end_token - 1))
+
+    if not valid_pairs:
+        return torch.zeros((0, 2), device=input_ids.device)
+
+    # Convert valid pairs to tensor
+    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+    return valid_pairs_tensor
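Note: to make the token-pair padding pattern concrete, a small sketch (assuming sglang 0.4.4.post3 and its dependencies are installed; MMStub is a stand-in for MultimodalInputs carrying only the attributes the pattern touches, and the token ids are made up):

from dataclasses import dataclass, field
from typing import List

from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternTokenPairs


@dataclass
class MMStub:  # stand-in for MultimodalInputs; only the fields used here
    pad_values: List[int]
    image_offsets: List[int] = field(default_factory=list)


IM_START, IM_END = 900, 901  # hypothetical <image> / </image> token ids
pattern = MultiModalityDataPaddingPatternTokenPairs([(IM_START, IM_END)])
mm = MMStub(pad_values=[555])

padded = pattern.pad_input_tokens([7, 8, IM_START, 1, 2, 3, IM_END, 9], mm)
# Tokens strictly between the pair are replaced by the image's pad value,
# and the start offset is recorded for the later embedding step:
assert padded == [7, 8, IM_START, 555, 555, 555, IM_END, 9]
assert mm.image_offsets == [2]

embed_mm_inputs later finds exactly these pad values with torch.isin and scatters the vision embeddings over them, which is why the padded ids must survive tokenizer-side caching untouched.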
sglang/srt/managers/multimodal_processor.py
@@ -0,0 +1,68 @@
+# TODO: also move pad_input_ids into this module
+import importlib
+import inspect
+import logging
+import pkgutil
+from functools import lru_cache
+
+from transformers import PROCESSOR_MAPPING
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+)
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+PROCESSOR_MAPPING = {}
+
+
+class DummyMultimodalProcessor(BaseMultimodalProcessor):
+    def __init__(self):
+        pass
+
+    async def process_mm_data_async(self, *args, **kwargs):
+        return None
+
+
+def get_dummy_processor():
+    return DummyMultimodalProcessor()
+
+
+@lru_cache()
+def import_processors():
+    package_name = "sglang.srt.managers.multimodal_processors"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            try:
+                module = importlib.import_module(name)
+            except Exception as e:
+                logger.warning(f"Ignore import error when loading {name}: " f"{e}")
+                continue
+            all_members = inspect.getmembers(module, inspect.isclass)
+            classes = [
+                member
+                for name, member in all_members
+                if member.__module__ == module.__name__
+            ]
+            for cls in (
+                cls for cls in classes if issubclass(cls, BaseMultimodalProcessor)
+            ):
+                assert hasattr(cls, "models")
+                for arch in getattr(cls, "models"):
+                    PROCESSOR_MAPPING[arch] = cls
+
+
+def get_mm_processor(
+    hf_config, server_args: ServerArgs, processor
+) -> BaseMultimodalProcessor:
+    for model_cls, processor_cls in PROCESSOR_MAPPING.items():
+        if model_cls.__name__ in hf_config.architectures:
+            return processor_cls(hf_config, server_args, processor)
+    raise ValueError(
+        f"No processor registered for architecture: {hf_config.architectures}.\n"
+        f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
+    )
+
+self.image_proce
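Note: usage-wise, the registry is populated once at startup and then queried by architecture name. A hedged sketch (the real call site in sglang lives in the tokenizer manager; the commented placeholders stand in for the caller's config objects):

from sglang.srt.managers.multimodal_processor import (
    get_dummy_processor,
    get_mm_processor,
    import_processors,
)

import_processors()  # scan the package once; fills PROCESSOR_MAPPING

# hf_config, server_args, and the HF processor come from the caller;
# get_mm_processor matches hf_config.architectures against the `models`
# attribute each registered processor class declares.
# mm_processor = get_mm_processor(hf_config, server_args, hf_processor)

# Text-only models fall back to the no-op processor:
dummy = get_dummy_processor()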