sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/configs/janus_pro.py (new file)
@@ -0,0 +1,629 @@
+# Adapted from:
+# https://github.com/deepseek-ai/Janus/tree/main/janus/models
+
+from dataclasses import dataclass
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+from PIL.Image import Image
+from transformers import (
+    AutoImageProcessor,
+    AutoProcessor,
+    BaseImageProcessor,
+    BatchFeature,
+    LlamaConfig,
+    LlamaTokenizerFast,
+    PretrainedConfig,
+    ProcessorMixin,
+)
+from transformers.image_utils import to_numpy_array
+
+from sglang.srt.mm_utils import expand2square
+
+
+class DictToObject(dict):
+    def __init__(self, dictionary):
+        super(self).__init__(dictionary)
+
+        for key, value in dictionary.items():
+            if isinstance(value, dict):
+                value = DictToObject(value)
+            setattr(self, key, value)
+
+
+class VisionConfig(PretrainedConfig):
+    model_type = "vision"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class GenAlignerConfig(PretrainedConfig):
+    model_type = "gen_aligner"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class GenHeadConfig(PretrainedConfig):
+    model_type = "gen_head"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class AlignerConfig(PretrainedConfig):
+    model_type = "aligner"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+class GenVisionConfig(PretrainedConfig):
+    model_type = "gen_vision"
+    cls: str = ""
+    params = {}
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.cls = kwargs.get("cls", "")
+        if not isinstance(self.cls, str):
+            self.cls = self.cls.__name__
+
+        self.params = kwargs.get("params", {})
+
+
+@dataclass
+class SigLIPVisionCfg:
+    width: int = 1152
+    layers: Union[Tuple[int, int, int, int], int] = 27
+    heads: int = 16
+    patch_size: int = 14
+    image_size: Union[Tuple[int, int], int] = 336
+    global_pool: str = "map"
+    mlp_ratio: float = 3.7362
+    class_token: bool = False
+    num_classes: int = 0
+    use_checkpoint: bool = False
+
+
+class MultiModalityConfig(PretrainedConfig):
+    model_type = "multi_modality"
+    vision_config: VisionConfig
+    aligner_config: AlignerConfig
+
+    gen_vision_config: GenVisionConfig
+    gen_aligner_config: GenAlignerConfig
+    gen_head_config: GenHeadConfig
+
+    language_config: LlamaConfig
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        vision_config = kwargs.get("vision_config", {})
+        self.vision_config = VisionConfig(**vision_config)
+
+        aligner_config = kwargs.get("aligner_config", {})
+        self.aligner_config = AlignerConfig(**aligner_config)
+
+        gen_vision_config = kwargs.get("gen_vision_config", {})
+        self.gen_vision_config = GenVisionConfig(**gen_vision_config)
+
+        gen_aligner_config = kwargs.get("gen_aligner_config", {})
+        self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
+
+        gen_head_config = kwargs.get("gen_head_config", {})
+        self.gen_head_config = GenHeadConfig(**gen_head_config)
+
+        language_config = kwargs.get("language_config", {})
+        if isinstance(language_config, LlamaConfig):
+            self.language_config = language_config
+        else:
+            self.language_config = LlamaConfig(**language_config)
+
+
+class VLMImageProcessor(BaseImageProcessor):
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        image_size: int,
+        min_size: int = 14,
+        image_mean: Union[Tuple[float, float, float], List[float]] = (
+            0.48145466,
+            0.4578275,
+            0.40821073,
+        ),
+        image_std: Union[Tuple[float, float, float], List[float]] = (
+            0.26862954,
+            0.26130258,
+            0.27577711,
+        ),
+        rescale_factor: float = 1.0 / 255.0,
+        do_normalize: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.image_size = image_size
+        self.rescale_factor = rescale_factor
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.min_size = min_size
+        self.do_normalize = do_normalize
+
+        if image_mean is None:
+            self.background_color = (127, 127, 127)
+        else:
+            self.background_color = tuple([int(x * 255) for x in image_mean])
+
+    def resize(self, pil_img: Image) -> np.ndarray:
+        """
+
+        Args:
+            pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB
+
+        Returns:
+            x (np.ndarray): [3, self.image_size, self.image_size]
+        """
+
+        width, height = pil_img.size
+        max_size = max(width, height)
+
+        size = [
+            max(int(height / max_size * self.image_size), self.min_size),
+            max(int(width / max_size * self.image_size), self.min_size),
+        ]
+
+        if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0:
+            # print(f"orig size = {pil_img.size}, new size = {size}")
+            raise ValueError("Invalid size!")
+
+        def resize(
+            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
+        ):
+            if isinstance(size, int):
+                w, h = pil_img.size
+                if (w <= h and w == size) or (h <= w and h == size):
+                    return pil_img
+                if w < h:
+                    ow = size
+                    oh = int(size * h / w)
+                else:
+                    oh = size
+                    ow = int(size * w / h)
+                size = (ow, oh)
+            else:
+                size = (size[1], size[0])
+
+            return pil_img.resize(
+                size, resample=interpolation, reducing_gap=None if antialias else 3.0
+            )
+
+        pil_img = resize(
+            pil_img, size, interpolation=PIL.Image.Resampling.BICUBIC, antialias=True
+        )
+
+        pil_img = expand2square(pil_img, self.background_color)
+        x = to_numpy_array(pil_img)
+
+        # [H, W, 3] -> [3, H, W]
+        x = np.transpose(x, (2, 0, 1))
+
+        return x
+
+    def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature:
+        # resize and pad to [self.image_size, self.image_size]
+        # then convert from [H, W, 3] to [3, H, W]
+        if not isinstance(images, list):
+            images = [images]
+        images: List[np.ndarray] = [self.resize(image) for image in images]
+        images = [image[:3, ...] for image in images]
+
+        # rescale from [0, 255] -> [0, 1]
+        images = [
+            self.rescale(
+                image=image,
+                scale=self.rescale_factor,
+                input_data_format="channels_first",
+            )
+            for image in images
+        ]
+
+        # normalize
+        if self.do_normalize:
+            images = [
+                self.normalize(
+                    image=image,
+                    mean=self.image_mean,
+                    std=self.image_std,
+                    input_data_format="channels_first",
+                )
+                for image in images
+            ]
+        data = {"pixel_values": images}
+        return BatchFeature(data=data, tensor_type=return_tensors)
+
+    @property
+    def default_shape(self):
+        return [3, self.image_size, self.image_size]
+
+
+class DictOutput(object):
+    def keys(self):
+        return self.__dict__.keys()
+
+    def __getitem__(self, item):
+        return self.__dict__[item]
+
+    def __setitem__(self, key, value):
+        self.__dict__[key] = value
+
+
+@dataclass
+class VLChatProcessorOutput(DictOutput):
+    sft_format: str
+    input_ids: torch.Tensor
+    pixel_values: torch.Tensor
+    num_image_tokens: torch.IntTensor
+
+    def __len__(self):
+        return len(self.input_ids)
+
+
+@dataclass
+class BatchedVLChatProcessorOutput(DictOutput):
+    sft_format: List[str]
+    input_ids: torch.Tensor
+    pixel_values: torch.Tensor
+    attention_mask: torch.Tensor
+    images_seq_mask: torch.BoolTensor
+    images_emb_mask: torch.BoolTensor
+
+
+# FIXME: had to place Official Processor here, since image_processor module would not be imported in all threads,
+# hence AutoProcessor registration would not be affective in some cases
+class VLChatProcessor(ProcessorMixin):
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+
+    attributes = ["image_processor", "tokenizer"]
+
+    def __init__(
+        self,
+        image_processor: VLMImageProcessor,
+        tokenizer: LlamaTokenizerFast,
+        image_tag: str = "<image_placeholder>",
+        image_start_tag: str = "<begin_of_image>",
+        image_end_tag: str = "<end_of_image>",
+        pad_tag: str = "<|▁pad▁|>",
+        num_image_tokens: int = 576,
+        add_special_token: bool = False,
+        sft_format: str = "deepseek",
+        mask_prompt: bool = True,
+        ignore_id: int = -100,
+        **kwargs,
+    ):
+        self.image_processor = image_processor
+        self.tokenizer = tokenizer
+
+        image_id = self.tokenizer.vocab.get(image_tag)
+        if image_id is None:
+            special_tokens = [image_tag]
+            special_tokens_dict = {"additional_special_tokens": special_tokens}
+            self.tokenizer.add_special_tokens(special_tokens_dict)
+            # print(f"Add image tag = {image_tag} to the tokenizer")
+
+        self.image_tag = image_tag
+        self.image_start_tag = image_start_tag
+        self.image_end_tag = image_end_tag
+        self.pad_tag = pad_tag
+
+        self.num_image_tokens = num_image_tokens
+        self.add_special_token = add_special_token
+        self.sft_format = sft_format
+        self.ignore_id = ignore_id
+
+        super().__init__(
+            image_processor,
+            tokenizer,
+            **kwargs,
+        )
+
+    @property
+    def image_token(self):
+        return self.image_tag
+
+    @property
+    def image_id(self) -> int:
+        image_id = self.tokenizer.vocab.get(self.image_tag)
+        return image_id
+
+    @property
+    def image_start_id(self):
+        image_start_id = self.tokenizer.vocab.get(self.image_start_tag)
+        return image_start_id
+
+    @property
+    def image_end_id(self):
+        image_end_id = self.tokenizer.vocab.get(self.image_end_tag)
+        return image_end_id
+
+    @property
+    def image_start_token(self):
+        return self.image_start_tag
+
+    @property
+    def image_end_token(self):
+        return self.image_end_tag
+
+    @property
+    def pad_id(self):
+        pad_id = self.tokenizer.vocab.get(self.pad_tag)
+        return pad_id
+
+    def add_image_token(
+        self,
+        image_indices: List[int],
+        input_ids: torch.LongTensor,
+    ):
+        """
+
+        Args:
+            image_indices (List[int]): [index_0, index_1, ..., index_j]
+            input_ids (torch.LongTensor): [N]
+
+        Returns:
+            input_ids (torch.LongTensor): [N + image tokens]
+            num_image_tokens (torch.IntTensor): [n_images]
+        """
+
+        input_slices = []
+
+        start = 0
+        for index in image_indices:
+            if self.add_special_token:
+                end = index + 1
+            else:
+                end = index
+
+            # original text tokens
+            input_slices.append(input_ids[start:end])
+
+            # add boi, image tokens, eoi and set the mask as False
+            input_slices.append(self.image_start_id * torch.ones((1), dtype=torch.long))
+            input_slices.append(
+                self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long)
+            )
+            input_slices.append(self.image_end_id * torch.ones((1), dtype=torch.long))
+            start = index + 1
+
+        # the left part
+        input_slices.append(input_ids[start:])
+
+        # concat all slices
+        input_ids = torch.cat(input_slices, dim=0)
+        num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices))
+
+        return input_ids, num_image_tokens
+
+    def process_one(
+        self,
+        prompt: str = None,
+        images: List[Image] = None,
+        **kwargs,
+    ):
+        """
+
+        Args:
+            prompt (str): the formatted prompt;
+            images (List[ImageType]): the list of images;
+            **kwargs:
+
+        Returns:
+            outputs (BaseProcessorOutput): the output of the processor,
+                - input_ids (torch.LongTensor): [N + image tokens]
+                - target_ids (torch.LongTensor): [N + image tokens]
+                - images (torch.FloatTensor): [n_images, 3, H, W]
+                - image_id (int): the id of the image token
+                - num_image_tokens (List[int]): the number of image tokens
+        """
+
+        sft_format = prompt
+        # tokenize
+        input_ids = self.tokenizer.encode(sft_format)
+        input_ids = torch.LongTensor(input_ids)
+
+        # add image tokens to the input_ids
+        image_token_mask: torch.Tensor = (input_ids == self.image_id).to(torch.bool)
+        image_indices = image_token_mask.nonzero()
+        input_ids, num_image_tokens = self.add_image_token(
+            image_indices=image_indices,
+            input_ids=input_ids,
+        )
+
+        # load images
+        images_outputs = self.image_processor(images, return_tensors="pt")
+
+        prepare = VLChatProcessorOutput(
+            sft_format=sft_format,
+            input_ids=input_ids,
+            pixel_values=images_outputs.pixel_values,
+            num_image_tokens=num_image_tokens,
+        )
+
+        return prepare
+
+    def __call__(
+        self,
+        *,
+        prompt: str = None,
+        conversations: List[Dict[str, str]] = None,
+        images: List[Image] = None,
+        force_batchify: bool = True,
+        **kwargs,
+    ):
+        """
+
+        Args:
+            prompt (str): the formatted prompt;
+            conversations (List[Dict]): conversations with a list of messages;
+            images (List[ImageType]): the list of images;
+            force_batchify (bool): force batchify the inputs;
+            **kwargs:
+
+        Returns:
+            outputs (BaseProcessorOutput): the output of the processor,
+                - input_ids (torch.LongTensor): [N + image tokens]
+                - images (torch.FloatTensor): [n_images, 3, H, W]
+                - image_id (int): the id of the image token
+                - num_image_tokens (List[int]): the number of image tokens
+        """
+
+        prepare = self.process_one(
+            prompt=prompt, conversations=conversations, images=images
+        )
+
+        if force_batchify:
+            prepare = self.batchify([prepare])
+
+        return prepare
+
+    def batchify(
+        self, prepare_list: List[VLChatProcessorOutput]
+    ) -> BatchedVLChatProcessorOutput:
+        """
+        Preprocesses the inputs for multimodal inference.
+
+        Args:
+            prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput.
+
+        Returns:
+            BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference.
+        """
+
+        batch_size = len(prepare_list)
+        sft_format = []
+        n_images = []
+        seq_lens = []
+        for prepare in prepare_list:
+            n_images.append(len(prepare.num_image_tokens))
+            seq_lens.append(len(prepare))
+
+        input_token_max_len = max(seq_lens)
+        max_n_images = max(1, max(n_images))
+
+        batched_input_ids = torch.full(
+            (batch_size, input_token_max_len), self.pad_id
+        ).long()  # FIXME
+        batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long()
+        batched_pixel_values = torch.zeros(
+            (batch_size, max_n_images, *self.image_processor.default_shape)
+        ).float()
+        batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool()
+        batched_images_emb_mask = torch.zeros(
+            (batch_size, max_n_images, self.num_image_tokens)
+        ).bool()
+
+        for i, prepare in enumerate(prepare_list):
+            input_ids = prepare.input_ids
+            seq_len = len(prepare)
+            n_image = len(prepare.num_image_tokens)
+            # left-padding
+            batched_attention_mask[i, -seq_len:] = 1
+            batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids)
+            batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id
+
+            if n_image > 0:
+                batched_pixel_values[i, :n_image] = prepare.pixel_values
+                for j, n_image_tokens in enumerate(prepare.num_image_tokens):
+                    batched_images_emb_mask[i, j, :n_image_tokens] = True
+
+            sft_format.append(prepare.sft_format)
+
+        batched_prepares = BatchedVLChatProcessorOutput(
+            input_ids=batched_input_ids,
+            attention_mask=batched_attention_mask,
+            pixel_values=batched_pixel_values,
+            images_seq_mask=batched_images_seq_mask,
+            images_emb_mask=batched_images_emb_mask,
+            sft_format=sft_format,
+        )
+
+        return batched_prepares
+
+
+class VLMImageProcessorConfig(PretrainedConfig):
+    model_type = "deepseek_vlm"
+    image_size: int
+    min_size: int
+    image_mean: Union[Tuple[float, float, float], List[float]]
+    image_std: Union[Tuple[float, float, float], List[float]]
+    rescale_factor: float
+    do_normalize: bool
+
+    def __init__(
+        self,
+        image_size: int,
+        min_size: int = 14,
+        image_mean: Union[Tuple[float, float, float], List[float]] = (
+            0.48145466,
+            0.4578275,
+            0.40821073,
+        ),
+        image_std: Union[Tuple[float, float, float], List[float]] = (
+            0.26862954,
+            0.26130258,
+            0.27577711,
+        ),
+        rescale_factor: float = 1.0 / 255.0,
+        do_normalize: bool = True,
+        **kwargs,
+    ):
+        self.image_size = image_size
+        self.min_size = min_size
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+
+        super().__init__(**kwargs)
+
+
+AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True)
+AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None)
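Note on the new module above: janus_pro.py ports the Janus processor and config classes and registers them with transformers' AutoProcessor/AutoImageProcessor so Janus Pro checkpoints resolve through the usual Auto APIs. The sketch below shows how the pieces fit together; it is illustrative only and not part of the diff. The checkpoint id, image size (384), and prompt are assumptions, and inside sglang these classes are driven by the new image-processor manager (sglang/srt/managers/image_processors/janus_pro.py) rather than called directly.

# Illustrative sketch, not part of the diff. Assumes a Janus-Pro-style checkpoint
# whose tokenizer already defines <image_placeholder>/<begin_of_image>/<end_of_image>;
# the repo id, image size, and prompt below are assumptions.
from PIL import Image
from transformers import AutoTokenizer

from sglang.srt.configs.janus_pro import VLChatProcessor, VLMImageProcessor

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/Janus-Pro-1B")
image_processor = VLMImageProcessor(image_size=384)
processor = VLChatProcessor(image_processor=image_processor, tokenizer=tokenizer)

prompt = "You are a helpful assistant.\n\n<image_placeholder>\nDescribe this image."
image = Image.open("example.jpg").convert("RGB")

# Each <image_placeholder> token is expanded to num_image_tokens (576) image ids
# wrapped in <begin_of_image>/<end_of_image>, then the batch is left-padded.
batch = processor(prompt=prompt, images=[image], force_batchify=True)
print(batch.input_ids.shape, batch.pixel_values.shape)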
sglang/srt/configs/model_config.py
@@ -81,7 +81,7 @@ class ModelConfig:
         if context_length is not None:
             if context_length > derived_context_len:
                 if get_bool_env_var(
-                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="False"
+                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
                 ):
                     logger.warning(
                         f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
@@ -237,6 +237,7 @@ class ModelConfig:
             "compressed_tensors",
             "compressed-tensors",
             "fbgemm_fp8",
+            "w8a8_fp8",
         ]
         optimized_quantization_methods = [
             "fp8",
@@ -250,9 +251,11 @@ class ModelConfig:
             "compressed-tensors",
             "experts_int8",
             "w8a8_int8",
+            "w8a8_fp8",
         ]
         compatible_quantization_methods = {
-            "w8a8_int8": ["compressed-tensors", "compressed_tensors"]
+            "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
+            "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
         }
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
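The two hunks above wire in the new w8a8_fp8 quantization method (implemented in the added sglang/srt/layers/quantization/w8a8_fp8.py), listing it as both supported and optimized and declaring it compatible with checkpoints that ship compressed-tensors quantization metadata. The snippet below is a simplified, hypothetical illustration of how such a compatibility table is consulted; the actual validation in model_config.py handles more cases.

# Hypothetical, simplified illustration of the compatibility table above;
# the real ModelConfig validation logic is more involved.
from typing import Optional

compatible_quantization_methods = {
    "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
    "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
}


def may_override(requested: str, checkpoint_method: Optional[str]) -> bool:
    # An unquantized checkpoint, or one already using the requested method,
    # needs no override; otherwise the checkpoint's method must be listed as
    # compatible with the requested one.
    if checkpoint_method is None or checkpoint_method == requested:
        return True
    return checkpoint_method in compatible_quantization_methods.get(requested, [])


assert may_override("w8a8_fp8", "compressed-tensors")
assert not may_override("w8a8_fp8", "awq")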
@@ -405,7 +408,7 @@ def _get_and_verify_dtype(
 
 def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
-    # 1. Check the model architectue
+    # 1. Check the model architecture
     # 2. check the `is_embedding` server args
 
     if (
@@ -421,18 +424,25 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
     return not is_embedding
 
 
+multimodal_model_archs = [
+    "LlavaLlamaForCausalLM",
+    "LlavaQwenForCausalLM",
+    "LlavaMistralForCausalLM",
+    "LlavaVidForCausalLM",
+    "Grok1VForCausalLM",
+    "Grok1AForCausalLM",
+    "MllamaForConditionalGeneration",
+    "Qwen2VLForConditionalGeneration",
+    "Qwen2_5_VLForConditionalGeneration",
+    "MiniCPMV",
+    "MultiModalityCausalLM",
+]
+
+
 def is_multimodal_model(model_architectures: List[str]):
-    if (
-        "LlavaLlamaForCausalLM" in model_architectures
-        or "LlavaQwenForCausalLM" in model_architectures
-        or "LlavaMistralForCausalLM" in model_architectures
-        or "LlavaVidForCausalLM" in model_architectures
-        or "Grok1VForCausalLM" in model_architectures
-        or "Grok1AForCausalLM" in model_architectures
-        or "MllamaForConditionalGeneration" in model_architectures
-        or "Qwen2VLForConditionalGeneration" in model_architectures
-        or "Qwen2_5_VLForConditionalGeneration" in model_architectures
-        or "MiniCPMV" in model_architectures
+    if any(
+        multi_model_arch in model_architectures
+        for multi_model_arch in multimodal_model_archs
     ):
         return True
     else:
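The final hunk replaces the hard-coded or-chain in is_multimodal_model with membership tests against the new module-level multimodal_model_archs list, which now also covers MultiModalityCausalLM, the architecture name declared by the Janus Pro checkpoints whose config and model code are added elsewhere in this release. A self-contained sketch of the resulting check (simplified to return a bool directly):

# Self-contained sketch of the refactored membership test; the list is copied
# from the hunk above, and the simplified return is an illustration.
from typing import List

multimodal_model_archs = [
    "LlavaLlamaForCausalLM",
    "LlavaQwenForCausalLM",
    "LlavaMistralForCausalLM",
    "LlavaVidForCausalLM",
    "Grok1VForCausalLM",
    "Grok1AForCausalLM",
    "MllamaForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "MiniCPMV",
    "MultiModalityCausalLM",
]


def is_multimodal_model(model_architectures: List[str]) -> bool:
    # any() over the shared list replaces the previous chain of `or` comparisons.
    return any(arch in model_architectures for arch in multimodal_model_archs)


assert is_multimodal_model(["MultiModalityCausalLM"])  # Janus Pro
assert not is_multimodal_model(["LlamaForCausalLM"])   # text-only model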