sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to their public registries. It is provided for informational purposes only.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/configs/deepseek_ocr.py

@@ -1,8 +1,22 @@
-from typing import Tuple
-
-import torchvision.transforms as T
-from PIL import Image
-from transformers import PretrainedConfig
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from PIL import Image, ImageOps
+from transformers import (
+    AutoProcessor,
+    LlamaTokenizerFast,
+    PretrainedConfig,
+    ProcessorMixin,
+)
+
+from sglang.srt.multimodal.customized_mm_processor_utils import (
+    register_customized_processor,
+)
+from sglang.srt.sampling.custom_logit_processor import (
+    DeepseekOCRNoRepeatNGramLogitProcessor,
+)
 
 BASE_SIZE = 1024
 IMAGE_SIZE = 640
@@ -15,21 +29,80 @@ PRINT_NUM_VIS_TOKENS = False
 SKIP_REPEAT = True
 MODEL_PATH = "deepseek-ai/DeepSeek-OCR"  # change to your model path
 
+NGRAM_NO_REPEAT_SIZE = 30
+NGRAM_NO_REPEAT_WINDOW = 90
+# Whitelist `<td>` and `</td>` token ids to allow table structures.
+NGRAM_NO_REPEAT_WHITELIST = (128821, 128822)
+
+DEFAULT_CUSTOM_LOGIT_PROCESSOR = DeepseekOCRNoRepeatNGramLogitProcessor.to_str()
+
+
+def get_default_ngram_custom_params() -> Dict[str, Any]:
+    """Return default custom params for the DeepSeek-OCR n-gram no repeat processor."""
+
+    return {
+        "ngram_size": NGRAM_NO_REPEAT_SIZE,
+        "window_size": NGRAM_NO_REPEAT_WINDOW,
+        "whitelist_token_ids": list(NGRAM_NO_REPEAT_WHITELIST),
+    }
+
+
 PROMPT = "<image>\n<|grounding|>Convert the document to markdown."
 
 
-class ImageTransform:
+class DictOutput(object):
+    def items(self):
+        return self.__dict__.items()
+
+    def keys(self):
+        return self.__dict__.keys()
+
+    def __getitem__(self, item):
+        return self.__dict__[item]
 
+    def __contains__(self, key):
+        return key in self.__dict__
+
+    def __setitem__(self, key, value):
+        self.__dict__[key] = value
+
+
+@dataclass
+class VLChatProcessorOutput(DictOutput):
+    input_ids: torch.LongTensor
+    target_ids: torch.LongTensor
+    images_crop: torch.LongTensor
+    pixel_values: (
+        torch.Tensor
+    )  # rename from "images" to "pixel_values" for compatibility
+    images_seq_mask: torch.BoolTensor
+    images_spatial_crop: torch.LongTensor
+
+    def __len__(self):
+        return len(self.input_ids)
+
+
+class ImageTransform(object):
     def __init__(
         self,
-        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
-        std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
         normalize: bool = True,
     ):
         self.mean = mean
         self.std = std
         self.normalize = normalize
 
+        # only load torchvision.transforms when needed
+        try:
+            import torchvision.transforms as T
+
+            # FIXME: add version check for gguf
+        except ImportError as err:
+            raise ImportError(
+                "Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
+            ) from err
+
         transform_pipelines = [T.ToTensor()]
 
         if normalize:
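The NGRAM_NO_REPEAT_* constants added in this hunk configure the DeepseekOCRNoRepeatNGramLogitProcessor imported at the top of the file, evidently to curb repetition loops in OCR output: a candidate token is banned if it would complete a 30-token n-gram already emitted within the last 90 tokens, except for the whitelisted <td>/</td> ids that legitimate tables repeat. A minimal standalone sketch of that rule under this reading of the parameters (banned_next_tokens is a hypothetical helper, not sglang's implementation):

from typing import List, Sequence, Set


def banned_next_tokens(
    output_ids: List[int],
    ngram_size: int = 30,                         # NGRAM_NO_REPEAT_SIZE
    window_size: int = 90,                        # NGRAM_NO_REPEAT_WINDOW
    whitelist: Sequence[int] = (128821, 128822),  # NGRAM_NO_REPEAT_WHITELIST
) -> Set[int]:
    """Tokens that would complete an n-gram already seen in the recent window."""
    if len(output_ids) < ngram_size:
        return set()
    window = output_ids[-window_size:]
    # The (n-1)-token suffix we are about to extend.
    prefix = tuple(output_ids[-(ngram_size - 1):])
    banned = set()
    # Ban every token that previously followed this same suffix in the window.
    for i in range(len(window) - ngram_size + 1):
        if tuple(window[i : i + ngram_size - 1]) == prefix:
            banned.add(window[i + ngram_size - 1])
    return banned - set(whitelist)

A real logit processor would set the logits of the returned ids to -inf before sampling; get_default_ngram_custom_params above packages these same values as the processor's custom params.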
@@ -42,6 +115,464 @@ class ImageTransform:
         return x
 
 
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(
+    image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
+):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set(
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = (
+            (i % (target_width // image_size)) * image_size,
+            (i // (target_width // image_size)) * image_size,
+            ((i % (target_width // image_size)) + 1) * image_size,
+            ((i // (target_width // image_size)) + 1) * image_size,
+        )
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images, target_aspect_ratio
+
+
+class DeepseekOCRProcessor(ProcessorMixin):
+    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+    attributes = ["tokenizer"]
+
+    def __init__(
+        self,
+        tokenizer: LlamaTokenizerFast,
+        candidate_resolutions: Tuple[Tuple[int, int]],
+        patch_size: int,
+        downsample_ratio: int,
+        image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        normalize: bool = True,
+        image_token: str = "<image>",
+        pad_token: str = "<|▁pad▁|>",
+        add_special_token: bool = False,
+        sft_format: str = "deepseek",
+        mask_prompt: bool = True,
+        ignore_id: int = -100,
+        **kwargs,
+    ):
+
+        self.candidate_resolutions = candidate_resolutions
+        self.image_size = candidate_resolutions[0][0]
+        self.patch_size = patch_size
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.normalize = normalize
+        self.downsample_ratio = downsample_ratio
+        self.base_size = BASE_SIZE
+        self.image_transform = ImageTransform(
+            mean=image_mean, std=image_std, normalize=normalize
+        )
+        self.tokenizer = tokenizer
+        # must set this,padding side with make a difference in batch inference
+        self.tokenizer.padding_side = "left"
+
+        # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
+        if tokenizer.pad_token is None:
+            self.tokenizer.add_special_tokens({"pad_token": pad_token})
+
+        # add image token
+        image_token_id = self.tokenizer.vocab.get(image_token)
+        if image_token_id is None:
+            special_tokens = [image_token]
+            special_tokens_dict = {"additional_special_tokens": special_tokens}
+            self.tokenizer.add_special_tokens(special_tokens_dict)
+        self.image_token_id = self.tokenizer.vocab.get(image_token)
+
+        # add five special tokens for grounding-related tasks
+        # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
+        special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
+        special_tokens_dict = {"additional_special_tokens": special_tokens}
+        self.tokenizer.add_special_tokens(special_tokens_dict)
+
+        # add special tokens for SFT data
+        special_tokens = ["<|User|>", "<|Assistant|>"]
+        special_tokens_dict = {"additional_special_tokens": special_tokens}
+        self.tokenizer.add_special_tokens(special_tokens_dict)
+
+        self.image_token = image_token
+        self.pad_token = pad_token
+        self.add_special_token = add_special_token
+        self.sft_format = sft_format
+        self.mask_prompt = mask_prompt
+        self.ignore_id = ignore_id
+
+        super().__init__(
+            tokenizer,
+            **kwargs,
+        )
+
+    def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
+        """play the role of format_messages_v2 and get_images_info in the last version"""
+        tokenized_data = []
+        masked_tokenized_data = []  # labels
+        images_list = []
+        images_seq_mask = []
+        images_spatial_crop = []
+
+        image_index = 0
+        image_token_cnt = messages.count(self.image_token)
+        (
+            input_ids,
+            images,
+            images_crop,
+            seq_mask,
+            spatial_crop,
+            num_image_tokens,
+            image_shapes,
+        ) = self.tokenize_with_images(
+            messages,
+            pil_images[image_index : image_index + image_token_cnt],
+            bos=True,
+            eos=True,
+            cropping=len(pil_images) <= 2,
+        )
+
+        image_index = image_token_cnt
+        images_list += images
+        images_seq_mask += seq_mask
+        images_spatial_crop = spatial_crop
+
+        return (
+            input_ids,
+            masked_tokenized_data,
+            images_list,
+            images_seq_mask,
+            images_spatial_crop,
+            images_crop,
+        )
+
+    @property
+    def bos_id(self):
+        return self.tokenizer.bos_token_id
+
+    @property
+    def eos_id(self):
+        return self.tokenizer.eos_token_id
+
+    @property
+    def pad_id(self):
+        return self.tokenizer.pad_token_id
+
+    def encode(self, text: str, bos: bool = True, eos: bool = False):
+        t = self.tokenizer.encode(text, add_special_tokens=False)
+
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+
+        return t
+
+    def decode(self, t: List[int], **kwargs) -> str:
+        return self.tokenizer.decode(t, **kwargs)
+
+    def process_one(
+        self,
+        prompt: str = None,
+        conversations: List[Dict[str, str]] = None,
+        images: List[Image.Image] = None,
+        apply_sft_format: bool = False,
+        inference_mode: bool = True,
+        system_prompt: str = "",
+        max_req_input_len: int = -1,
+        cropping: bool = True,
+        **kwargs,
+    ):
+        """
+
+        Args:
+            prompt (str): the formatted prompt;
+            conversations (List[Dict]): conversations with a list of messages;
+            images (List[ImageType]): the list of images;
+            apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;
+                if conversations is not None, then it will always apply the SFT format to conversations;
+            inference_mode (bool): if True, then remove the last eos token;
+            system_prompt (str): the system prompt;
+            **kwargs:
+
+        Returns:
+            outputs (BaseProcessorOutput): the output of the processor,
+                - input_ids (torch.LongTensor): [N + image tokens]
+                - target_ids (torch.LongTensor): [N + image tokens]
+                - images (torch.FloatTensor): [n_images, 3, H, W]
+                - image_id (int): the id of the image token
+                - num_image_tokens (List[int]): the number of image tokens
+        """
+
+        prompt = conversations or prompt
+        (
+            input_ids,
+            masked_tokenized_str,
+            images_list,
+            images_seq_mask,
+            images_spatial_crop,
+            images_crop,
+        ) = self.format_messages_v2(prompt, images, max_req_input_len)
+
+        target_ids = torch.LongTensor(masked_tokenized_str)
+
+        if len(images_list) == 0:
+            images = torch.zeros((1, 3, self.image_size, self.image_size))
+        else:
+            images = torch.stack(images_list, dim=0)
+
+        images_spatial_crop = torch.stack(
+            [images_spatial_crop], dim=0
+        )  # stack the tensor to make it a batch of 1
+
+        prepare = VLChatProcessorOutput(
+            input_ids=input_ids,
+            target_ids=target_ids,
+            images_crop=images_crop,
+            pixel_values=images,
+            images_seq_mask=images_seq_mask,
+            images_spatial_crop=images_spatial_crop,
+        )
+
+        return prepare
+
+    def __call__(
+        self,
+        *,
+        prompt: str = None,
+        conversations: List[Dict[str, str]] = None,
+        images: List[Image.Image] = None,
+        apply_sft_format: bool = False,
+        inference_mode: bool = True,
+        system_prompt: str = "",
+        max_req_input_len: int = -1,
+        text: list[str] = None,
+        **kwargs,
+    ):
+        assert text is None or isinstance(text, list)
+        if text is not None:
+            text = text[0]
+        prepare = self.process_one(
+            prompt=prompt or text,
+            conversations=conversations,
+            images=images,
+            apply_sft_format=apply_sft_format,
+            inference_mode=inference_mode,
+            system_prompt=system_prompt,
+            max_req_input_len=max_req_input_len,
+        )
+
+        return prepare
+
+    def find_all_indices(self, messages, target_value):
+        indices = []
+        for index, item in enumerate(messages):
+            if item == target_value:
+                indices.append(index)
+        return indices
+
+    def tokenize_with_images(
+        self,
+        conversation: str,
+        images: List[Image.Image],
+        bos: bool = True,
+        eos: bool = True,
+        cropping: bool = True,
+    ):
+        """Tokenize text with <image> tags."""
+
+        conversation = conversation
+        assert conversation.count(self.image_token) == len(images)
+        text_splits = conversation.split(self.image_token)
+        images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
+            [],
+            [],
+            [],
+            [],
+        )
+        image_shapes = []
+        num_image_tokens = []
+        tokenized_str = []
+        for text_sep, image in zip(text_splits, images):
+            """encode text_sep"""
+            tokenized_sep = self.encode(text_sep, bos=False, eos=False)
+
+            tokenized_str += tokenized_sep
+            images_seq_mask += [False] * len(tokenized_sep)
+
+            image_shapes.append(image.size)
+
+            if image.size[0] <= 640 and image.size[1] <= 640:
+                crop_ratio = [1, 1]
+            else:
+                if cropping:
+                    images_crop_raw, crop_ratio = dynamic_preprocess(
+                        image, image_size=IMAGE_SIZE
+                    )
+                else:
+                    crop_ratio = [1, 1]
+
+            """process the global view"""
+            if self.image_size <= 640 and not cropping:
+                image = image.resize((self.image_size, self.image_size))
+
+            global_view = ImageOps.pad(
+                image,
+                (self.base_size, self.base_size),
+                color=tuple(int(x * 255) for x in self.image_transform.mean),
+            )
+            images_list.append(self.image_transform(global_view))
+
+            num_width_tiles, num_height_tiles = crop_ratio
+            images_spatial_crop.append([num_width_tiles, num_height_tiles])
+
+            if num_width_tiles > 1 or num_height_tiles > 1:
+                for i in range(len(images_crop_raw)):
+                    images_crop_list.append(self.image_transform(images_crop_raw[i]))
+
+            """add image tokens"""
+            num_queries = math.ceil(
+                (self.image_size // self.patch_size) / self.downsample_ratio
+            )
+            num_queries_base = math.ceil(
+                (self.base_size // self.patch_size) / self.downsample_ratio
+            )
+
+            tokenized_image = (
+                [self.image_token_id] * num_queries_base + [self.image_token_id]
+            ) * num_queries_base
+            tokenized_image += [self.image_token_id]
+            if num_width_tiles > 1 or num_height_tiles > 1:
+                tokenized_image += (
+                    [self.image_token_id] * (num_queries * num_width_tiles)
+                    + [self.image_token_id]
+                ) * (num_queries * num_height_tiles)
+            tokenized_str += tokenized_image
+
+            images_seq_mask += [True] * len(tokenized_image)
+            num_image_tokens.append(len(tokenized_image))
+
+        """process the last text split"""
+        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
+
+        tokenized_str += tokenized_sep
+        images_seq_mask += [False] * len(tokenized_sep)
+
+        """add the bos and eos tokens"""
+        if bos:
+            tokenized_str = [self.bos_id] + tokenized_str
+            images_seq_mask = [False] + images_seq_mask
+        if eos:
+            tokenized_str = tokenized_str + [self.eos_id]
+            images_seq_mask = images_seq_mask + [False]
+
+        assert len(tokenized_str) == len(
+            images_seq_mask
+        ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
+
+        masked_tokenized_str = []
+        for token_index in tokenized_str:
+            if token_index != self.image_token_id:
+                masked_tokenized_str.append(token_index)
+            else:
+                masked_tokenized_str.append(self.ignore_id)
+
+        assert (
+            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
+        ), (
+            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
+            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
+        )
+        input_ids = torch.LongTensor(tokenized_str)
+        target_ids = torch.LongTensor(masked_tokenized_str)
+        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
+
+        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
+        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
+            self.ignore_id
+        )
+        input_ids[input_ids < 0] = self.pad_id
+
+        inference_mode = True
+
+        if inference_mode:
+            # Remove the ending eos token
+            assert input_ids[-1] == self.eos_id
+            input_ids = input_ids[:-1]
+            target_ids = target_ids[:-1]
+            images_seq_mask = images_seq_mask[:-1]
+
+        if len(images_list) == 0:
+            pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
+            images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
+            images_crop = torch.zeros(
+                (1, 3, self.image_size, self.image_size)
+            ).unsqueeze(0)
+        else:
+            pixel_values = torch.stack(images_list, dim=0)
+            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
+            if images_crop_list:
+                images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
+            else:
+                images_crop = torch.zeros(
+                    (1, 3, self.image_size, self.image_size)
+                ).unsqueeze(0)
+
+        input_ids = input_ids.unsqueeze(0)
+        return (
+            input_ids,
+            pixel_values,
+            images_crop,
+            images_seq_mask,
+            images_spatial_crop,
+            num_image_tokens,
+            image_shapes,
+        )
+
+
 class VisionEncoderConfig(PretrainedConfig):
     model_type: str = "vision"
 
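The dynamic_preprocess helper added in this hunk is the tiling scheme tokenize_with_images relies on: it picks the tile grid whose aspect ratio is closest to the input's, resizes the page to that grid, and cuts 640x640 crops. A quick usage sketch with explicit crop bounds (the real defaults come from the module's MIN_CROPS/MAX_CROPS constants, which lie outside this diff):

from PIL import Image

page = Image.new("RGB", (1280, 960))  # a 4:3 page
tiles, grid = dynamic_preprocess(page, min_num=2, max_num=6, image_size=640)
# Among grids with 2..6 tiles, 3x2 (ratio 1.5) is closest to 4:3 (~1.33),
# so the page is resized to 1920x1280 and cut into six 640x640 crops.
assert grid == (3, 2) and len(tiles) == 6 and tiles[0].size == (640, 640)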
@@ -223,6 +754,7 @@ class DeepseekV2Config(PretrainedConfig):
         )
 
 
+@register_customized_processor(processor_class=DeepseekOCRProcessor)
 class DeepseekVLV2Config(PretrainedConfig):
     # model_type = "deepseek_vl_v2"
     model_type = "deepseek-ocr"
@@ -232,6 +764,7 @@ class DeepseekVLV2Config(PretrainedConfig):
     tile_tag: str = "2D"
     global_view_pos: str = "head"
     candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),)
+    customized_processor_type: type[Any] = DeepseekOCRProcessor
 
     def __init__(
         self,
@@ -258,5 +791,4 @@ class DeepseekVLV2Config(PretrainedConfig):
         self.hidden_size = self.text_config.hidden_size
 
 
-class DeepseekOCRConfig(DeepseekV2Config):
-    model_type = "DeepseekOCR"
+AutoProcessor.register(DeepseekVLV2Config, DeepseekOCRProcessor)
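The AutoProcessor.register call that now ends the file, together with the new @register_customized_processor decorator on DeepseekVLV2Config, maps that config class to DeepseekOCRProcessor. A sketch of the intended effect, assuming the checkpoint's config resolves to DeepseekVLV2Config (exact tensor shapes depend on the checkpoint's patch_size and downsample_ratio):

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("deepseek-ai/DeepSeek-OCR")
page = Image.new("RGB", (1024, 1024))
out = processor(
    prompt="<image>\n<|grounding|>Convert the document to markdown.",
    images=[page],
)
# out is a VLChatProcessorOutput: input_ids is (1, seq_len); pixel_values holds
# the padded 1024x1024 global view; images_crop holds the 640x640 local tiles.
print(out.input_ids.shape, out.pixel_values.shape, out.images_crop.shape)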