sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +18 -3
- sglang/compile_deep_gemm.py +13 -7
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +25 -2
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -5
- sglang/srt/entrypoints/engine.py +13 -5
- sglang/srt/entrypoints/http_server.py +22 -3
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +7 -0
- sglang/srt/eplb/expert_distribution.py +34 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +7 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
- sglang/srt/layers/communicator.py +23 -1
- sglang/srt/layers/layernorm.py +16 -2
- sglang/srt/layers/logits_processor.py +4 -20
- sglang/srt/layers/moe/ep_moe/layer.py +0 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
- sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
- sglang/srt/layers/moe/topk.py +31 -6
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +9 -78
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/rotary_embedding.py +117 -45
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +26 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +164 -129
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +154 -59
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +171 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +11 -11
- sglang/srt/model_executor/model_runner.py +76 -21
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +149 -34
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +0 -1
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +1 -1
- sglang/srt/models/qwen3_moe.py +16 -8
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +103 -22
- sglang/srt/single_batch_overlap.py +4 -1
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +55 -32
- sglang/srt/utils/hf_transformers_utils.py +38 -16
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/configs/deepseek_ocr.py

@@ -1,8 +1,22 @@
-
-
-import
-
-
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from PIL import Image, ImageOps
+from transformers import (
+    AutoProcessor,
+    LlamaTokenizerFast,
+    PretrainedConfig,
+    ProcessorMixin,
+)
+
+from sglang.srt.multimodal.customized_mm_processor_utils import (
+    register_customized_processor,
+)
+from sglang.srt.sampling.custom_logit_processor import (
+    DeepseekOCRNoRepeatNGramLogitProcessor,
+)

 BASE_SIZE = 1024
 IMAGE_SIZE = 640
@@ -15,21 +29,80 @@ PRINT_NUM_VIS_TOKENS = False
 SKIP_REPEAT = True
 MODEL_PATH = "deepseek-ai/DeepSeek-OCR"  # change to your model path

+NGRAM_NO_REPEAT_SIZE = 30
+NGRAM_NO_REPEAT_WINDOW = 90
+# Whitelist `<td>` and `</td>` token ids to allow table structures.
+NGRAM_NO_REPEAT_WHITELIST = (128821, 128822)
+
+DEFAULT_CUSTOM_LOGIT_PROCESSOR = DeepseekOCRNoRepeatNGramLogitProcessor.to_str()
+
+
+def get_default_ngram_custom_params() -> Dict[str, Any]:
+    """Return default custom params for the DeepSeek-OCR n-gram no repeat processor."""
+
+    return {
+        "ngram_size": NGRAM_NO_REPEAT_SIZE,
+        "window_size": NGRAM_NO_REPEAT_WINDOW,
+        "whitelist_token_ids": list(NGRAM_NO_REPEAT_WHITELIST),
+    }
+
+
 PROMPT = "<image>\n<|grounding|>Convert the document to markdown."


-class ImageTransform:
+class DictOutput(object):
+    def items(self):
+        return self.__dict__.items()
+
+    def keys(self):
+        return self.__dict__.keys()
+
+    def __getitem__(self, item):
+        return self.__dict__[item]

+    def __contains__(self, key):
+        return key in self.__dict__
+
+    def __setitem__(self, key, value):
+        self.__dict__[key] = value
+
+
+@dataclass
+class VLChatProcessorOutput(DictOutput):
+    input_ids: torch.LongTensor
+    target_ids: torch.LongTensor
+    images_crop: torch.LongTensor
+    pixel_values: (
+        torch.Tensor
+    )  # rename from "images" to "pixel_values" for compatibility
+    images_seq_mask: torch.BoolTensor
+    images_spatial_crop: torch.LongTensor
+
+    def __len__(self):
+        return len(self.input_ids)
+
+
+class ImageTransform(object):
     def __init__(
         self,
-        mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
-        std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        mean: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+        std: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
         normalize: bool = True,
     ):
         self.mean = mean
         self.std = std
         self.normalize = normalize

+        # only load torchvision.transforms when needed
+        try:
+            import torchvision.transforms as T
+
+            # FIXME: add version check for gguf
+        except ImportError as err:
+            raise ImportError(
+                "Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
+            ) from err
+
         transform_pipelines = [T.ToTensor()]

         if normalize:
@@ -42,6 +115,464 @@ class ImageTransform:
         return x


+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(
+    image, min_num=MIN_CROPS, max_num=MAX_CROPS, image_size=640, use_thumbnail=False
+):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = set(
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if i * j <= max_num and i * j >= min_num
+    )
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = (
+            (i % (target_width // image_size)) * image_size,
+            (i // (target_width // image_size)) * image_size,
+            ((i % (target_width // image_size)) + 1) * image_size,
+            ((i // (target_width // image_size)) + 1) * image_size,
+        )
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images, target_aspect_ratio
+
+
+class DeepseekOCRProcessor(ProcessorMixin):
+    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+    attributes = ["tokenizer"]
+
+    def __init__(
+        self,
+        tokenizer: LlamaTokenizerFast,
+        candidate_resolutions: Tuple[Tuple[int, int]],
+        patch_size: int,
+        downsample_ratio: int,
+        image_mean: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        image_std: Tuple[float, float, float] = (0.5, 0.5, 0.5),
+        normalize: bool = True,
+        image_token: str = "<image>",
+        pad_token: str = "<|▁pad▁|>",
+        add_special_token: bool = False,
+        sft_format: str = "deepseek",
+        mask_prompt: bool = True,
+        ignore_id: int = -100,
+        **kwargs,
+    ):
+
+        self.candidate_resolutions = candidate_resolutions
+        self.image_size = candidate_resolutions[0][0]
+        self.patch_size = patch_size
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.normalize = normalize
+        self.downsample_ratio = downsample_ratio
+        self.base_size = BASE_SIZE
+        self.image_transform = ImageTransform(
+            mean=image_mean, std=image_std, normalize=normalize
+        )
+        self.tokenizer = tokenizer
+        # must set this,padding side with make a difference in batch inference
+        self.tokenizer.padding_side = "left"
+
+        # add the pad_token as special token to use 'tokenizer.pad_token' and 'tokenizer.pad_token_id'
+        if tokenizer.pad_token is None:
+            self.tokenizer.add_special_tokens({"pad_token": pad_token})
+
+        # add image token
+        image_token_id = self.tokenizer.vocab.get(image_token)
+        if image_token_id is None:
+            special_tokens = [image_token]
+            special_tokens_dict = {"additional_special_tokens": special_tokens}
+            self.tokenizer.add_special_tokens(special_tokens_dict)
+        self.image_token_id = self.tokenizer.vocab.get(image_token)
+
+        # add five special tokens for grounding-related tasks
+        # <|ref|>, <|/ref|>, <|det|>, <|/det|>, <|grounding|>
+        special_tokens = ["<|ref|>", "<|/ref|>", "<|det|>", "<|/det|>", "<|grounding|>"]
+        special_tokens_dict = {"additional_special_tokens": special_tokens}
+        self.tokenizer.add_special_tokens(special_tokens_dict)
+
+        # add special tokens for SFT data
+        special_tokens = ["<|User|>", "<|Assistant|>"]
+        special_tokens_dict = {"additional_special_tokens": special_tokens}
+        self.tokenizer.add_special_tokens(special_tokens_dict)
+
+        self.image_token = image_token
+        self.pad_token = pad_token
+        self.add_special_token = add_special_token
+        self.sft_format = sft_format
+        self.mask_prompt = mask_prompt
+        self.ignore_id = ignore_id
+
+        super().__init__(
+            tokenizer,
+            **kwargs,
+        )
+
+    def format_messages_v2(self, messages: str, pil_images, max_req_input_len=-1):
+        """play the role of format_messages_v2 and get_images_info in the last version"""
+        tokenized_data = []
+        masked_tokenized_data = []  # labels
+        images_list = []
+        images_seq_mask = []
+        images_spatial_crop = []
+
+        image_index = 0
+        image_token_cnt = messages.count(self.image_token)
+        (
+            input_ids,
+            images,
+            images_crop,
+            seq_mask,
+            spatial_crop,
+            num_image_tokens,
+            image_shapes,
+        ) = self.tokenize_with_images(
+            messages,
+            pil_images[image_index : image_index + image_token_cnt],
+            bos=True,
+            eos=True,
+            cropping=len(pil_images) <= 2,
+        )
+
+        image_index = image_token_cnt
+        images_list += images
+        images_seq_mask += seq_mask
+        images_spatial_crop = spatial_crop
+
+        return (
+            input_ids,
+            masked_tokenized_data,
+            images_list,
+            images_seq_mask,
+            images_spatial_crop,
+            images_crop,
+        )
+
+    @property
+    def bos_id(self):
+        return self.tokenizer.bos_token_id
+
+    @property
+    def eos_id(self):
+        return self.tokenizer.eos_token_id
+
+    @property
+    def pad_id(self):
+        return self.tokenizer.pad_token_id
+
+    def encode(self, text: str, bos: bool = True, eos: bool = False):
+        t = self.tokenizer.encode(text, add_special_tokens=False)
+
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+
+        return t
+
+    def decode(self, t: List[int], **kwargs) -> str:
+        return self.tokenizer.decode(t, **kwargs)
+
+    def process_one(
+        self,
+        prompt: str = None,
+        conversations: List[Dict[str, str]] = None,
+        images: List[Image.Image] = None,
+        apply_sft_format: bool = False,
+        inference_mode: bool = True,
+        system_prompt: str = "",
+        max_req_input_len: int = -1,
+        cropping: bool = True,
+        **kwargs,
+    ):
+        """
+
+        Args:
+            prompt (str): the formatted prompt;
+            conversations (List[Dict]): conversations with a list of messages;
+            images (List[ImageType]): the list of images;
+            apply_sft_format (bool): if prompt is not None, then apply the SFT format to prompt;
+                if conversations is not None, then it will always apply the SFT format to conversations;
+            inference_mode (bool): if True, then remove the last eos token;
+            system_prompt (str): the system prompt;
+            **kwargs:
+
+        Returns:
+            outputs (BaseProcessorOutput): the output of the processor,
+                - input_ids (torch.LongTensor): [N + image tokens]
+                - target_ids (torch.LongTensor): [N + image tokens]
+                - images (torch.FloatTensor): [n_images, 3, H, W]
+                - image_id (int): the id of the image token
+                - num_image_tokens (List[int]): the number of image tokens
+        """
+
+        prompt = conversations or prompt
+        (
+            input_ids,
+            masked_tokenized_str,
+            images_list,
+            images_seq_mask,
+            images_spatial_crop,
+            images_crop,
+        ) = self.format_messages_v2(prompt, images, max_req_input_len)
+
+        target_ids = torch.LongTensor(masked_tokenized_str)
+
+        if len(images_list) == 0:
+            images = torch.zeros((1, 3, self.image_size, self.image_size))
+        else:
+            images = torch.stack(images_list, dim=0)
+
+        images_spatial_crop = torch.stack(
+            [images_spatial_crop], dim=0
+        )  # stack the tensor to make it a batch of 1
+
+        prepare = VLChatProcessorOutput(
+            input_ids=input_ids,
+            target_ids=target_ids,
+            images_crop=images_crop,
+            pixel_values=images,
+            images_seq_mask=images_seq_mask,
+            images_spatial_crop=images_spatial_crop,
+        )
+
+        return prepare
+
+    def __call__(
+        self,
+        *,
+        prompt: str = None,
+        conversations: List[Dict[str, str]] = None,
+        images: List[Image.Image] = None,
+        apply_sft_format: bool = False,
+        inference_mode: bool = True,
+        system_prompt: str = "",
+        max_req_input_len: int = -1,
+        text: list[str] = None,
+        **kwargs,
+    ):
+        assert text is None or isinstance(text, list)
+        if text is not None:
+            text = text[0]
+        prepare = self.process_one(
+            prompt=prompt or text,
+            conversations=conversations,
+            images=images,
+            apply_sft_format=apply_sft_format,
+            inference_mode=inference_mode,
+            system_prompt=system_prompt,
+            max_req_input_len=max_req_input_len,
+        )
+
+        return prepare
+
+    def find_all_indices(self, messages, target_value):
+        indices = []
+        for index, item in enumerate(messages):
+            if item == target_value:
+                indices.append(index)
+        return indices
+
+    def tokenize_with_images(
+        self,
+        conversation: str,
+        images: List[Image.Image],
+        bos: bool = True,
+        eos: bool = True,
+        cropping: bool = True,
+    ):
+        """Tokenize text with <image> tags."""
+
+        conversation = conversation
+        assert conversation.count(self.image_token) == len(images)
+        text_splits = conversation.split(self.image_token)
+        images_list, images_crop_list, images_seq_mask, images_spatial_crop = (
+            [],
+            [],
+            [],
+            [],
+        )
+        image_shapes = []
+        num_image_tokens = []
+        tokenized_str = []
+        for text_sep, image in zip(text_splits, images):
+            """encode text_sep"""
+            tokenized_sep = self.encode(text_sep, bos=False, eos=False)
+
+            tokenized_str += tokenized_sep
+            images_seq_mask += [False] * len(tokenized_sep)
+
+            image_shapes.append(image.size)
+
+            if image.size[0] <= 640 and image.size[1] <= 640:
+                crop_ratio = [1, 1]
+            else:
+                if cropping:
+                    images_crop_raw, crop_ratio = dynamic_preprocess(
+                        image, image_size=IMAGE_SIZE
+                    )
+                else:
+                    crop_ratio = [1, 1]
+
+            """process the global view"""
+            if self.image_size <= 640 and not cropping:
+                image = image.resize((self.image_size, self.image_size))
+
+            global_view = ImageOps.pad(
+                image,
+                (self.base_size, self.base_size),
+                color=tuple(int(x * 255) for x in self.image_transform.mean),
+            )
+            images_list.append(self.image_transform(global_view))
+
+            num_width_tiles, num_height_tiles = crop_ratio
+            images_spatial_crop.append([num_width_tiles, num_height_tiles])
+
+            if num_width_tiles > 1 or num_height_tiles > 1:
+                for i in range(len(images_crop_raw)):
+                    images_crop_list.append(self.image_transform(images_crop_raw[i]))
+
+            """add image tokens"""
+            num_queries = math.ceil(
+                (self.image_size // self.patch_size) / self.downsample_ratio
+            )
+            num_queries_base = math.ceil(
+                (self.base_size // self.patch_size) / self.downsample_ratio
+            )
+
+            tokenized_image = (
+                [self.image_token_id] * num_queries_base + [self.image_token_id]
+            ) * num_queries_base
+            tokenized_image += [self.image_token_id]
+            if num_width_tiles > 1 or num_height_tiles > 1:
+                tokenized_image += (
+                    [self.image_token_id] * (num_queries * num_width_tiles)
+                    + [self.image_token_id]
+                ) * (num_queries * num_height_tiles)
+            tokenized_str += tokenized_image
+
+            images_seq_mask += [True] * len(tokenized_image)
+            num_image_tokens.append(len(tokenized_image))
+
+        """process the last text split"""
+        tokenized_sep = self.encode(text_splits[-1], bos=False, eos=False)
+
+        tokenized_str += tokenized_sep
+        images_seq_mask += [False] * len(tokenized_sep)
+
+        """add the bos and eos tokens"""
+        if bos:
+            tokenized_str = [self.bos_id] + tokenized_str
+            images_seq_mask = [False] + images_seq_mask
+        if eos:
+            tokenized_str = tokenized_str + [self.eos_id]
+            images_seq_mask = images_seq_mask + [False]
+
+        assert len(tokenized_str) == len(
+            images_seq_mask
+        ), f"tokenize_with_images func: tokenized_str's length {len(tokenized_str)} is not equal to imags_seq_mask's length {len(images_seq_mask)}"
+
+        masked_tokenized_str = []
+        for token_index in tokenized_str:
+            if token_index != self.image_token_id:
+                masked_tokenized_str.append(token_index)
+            else:
+                masked_tokenized_str.append(self.ignore_id)
+
+        assert (
+            len(tokenized_str) == len(images_seq_mask) == len(masked_tokenized_str)
+        ), (
+            f"tokenized_str's length {len(tokenized_str)}, input_ids' length {len(masked_tokenized_str)}, "
+            f"imags_seq_mask's length {len(images_seq_mask)}, are not equal"
+        )
+        input_ids = torch.LongTensor(tokenized_str)
+        target_ids = torch.LongTensor(masked_tokenized_str)
+        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)
+
+        # set input_ids < 0 | input_ids == self.image_token_id as ignore_id
+        target_ids[(input_ids < 0) | (input_ids == self.image_token_id)] = (
+            self.ignore_id
+        )
+        input_ids[input_ids < 0] = self.pad_id
+
+        inference_mode = True
+
+        if inference_mode:
+            # Remove the ending eos token
+            assert input_ids[-1] == self.eos_id
+            input_ids = input_ids[:-1]
+            target_ids = target_ids[:-1]
+            images_seq_mask = images_seq_mask[:-1]
+
+        if len(images_list) == 0:
+            pixel_values = torch.zeros((1, 3, self.base_size, self.base_size))
+            images_spatial_crop = torch.zeros((1, 1), dtype=torch.long)
+            images_crop = torch.zeros(
+                (1, 3, self.image_size, self.image_size)
+            ).unsqueeze(0)
+        else:
+            pixel_values = torch.stack(images_list, dim=0)
+            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
+            if images_crop_list:
+                images_crop = torch.stack(images_crop_list, dim=0).unsqueeze(0)
+            else:
+                images_crop = torch.zeros(
+                    (1, 3, self.image_size, self.image_size)
+                ).unsqueeze(0)
+
+        input_ids = input_ids.unsqueeze(0)
+        return (
+            input_ids,
+            pixel_values,
+            images_crop,
+            images_seq_mask,
+            images_spatial_crop,
+            num_image_tokens,
+            image_shapes,
+        )
+
+
 class VisionEncoderConfig(PretrainedConfig):
     model_type: str = "vision"

@@ -223,6 +754,7 @@ class DeepseekV2Config(PretrainedConfig):
         )


+@register_customized_processor(processor_class=DeepseekOCRProcessor)
 class DeepseekVLV2Config(PretrainedConfig):
     # model_type = "deepseek_vl_v2"
     model_type = "deepseek-ocr"
@@ -232,6 +764,7 @@ class DeepseekVLV2Config(PretrainedConfig):
     tile_tag: str = "2D"
     global_view_pos: str = "head"
     candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),)
+    customized_processor_type: type[Any] = DeepseekOCRProcessor

     def __init__(
         self,
@@ -258,5 +791,4 @@ class DeepseekVLV2Config(PretrainedConfig):
         self.hidden_size = self.text_config.hidden_size


-
-model_type = "DeepseekOCR"
+AutoProcessor.register(DeepseekVLV2Config, DeepseekOCRProcessor)
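
The new module constants above (`DEFAULT_CUSTOM_LOGIT_PROCESSOR`, `PROMPT`, `get_default_ngram_custom_params()`) give ready-made defaults for the DeepSeek-OCR n-gram no-repeat logit processor. The sketch below is illustrative only and is not part of the diff: it assumes sglang's existing custom-logit-processor plumbing (a server launched with `--enable-custom-logit-processor`, the serialized processor passed as `custom_logit_processor`, and its options under `sampling_params["custom_params"]`); the endpoint URL and image path are placeholders.

# Minimal usage sketch, under the assumptions stated above.
import requests

from sglang.srt.configs.deepseek_ocr import (
    DEFAULT_CUSTOM_LOGIT_PROCESSOR,
    PROMPT,
    get_default_ngram_custom_params,
)

payload = {
    "text": PROMPT,  # "<image>\n<|grounding|>Convert the document to markdown."
    "image_data": "document_page.png",  # placeholder path to the page image
    # Serialized DeepseekOCRNoRepeatNGramLogitProcessor (see the constants in the diff).
    "custom_logit_processor": DEFAULT_CUSTOM_LOGIT_PROCESSOR,
    "sampling_params": {
        "temperature": 0.0,
        "max_new_tokens": 2048,
        # {"ngram_size": 30, "window_size": 90, "whitelist_token_ids": [128821, 128822]}
        "custom_params": get_default_ngram_custom_params(),
    },
}

response = requests.post("http://localhost:30000/generate", json=payload)
print(response.json())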