openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1025 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import math
|
|
3
|
+
from typing import List, Optional, Tuple, Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
from scipy.ndimage import zoom as scizoom
|
|
7
|
+
|
|
8
|
+
# Transformers imports
|
|
9
|
+
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
|
10
|
+
from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
|
|
11
|
+
from transformers.image_utils import (
|
|
12
|
+
ChannelDimension,
|
|
13
|
+
ImageInput,
|
|
14
|
+
PILImageResampling,
|
|
15
|
+
infer_channel_dimension_format,
|
|
16
|
+
make_flat_list_of_images,
|
|
17
|
+
to_numpy_array,
|
|
18
|
+
valid_images,
|
|
19
|
+
validate_preprocess_arguments,
|
|
20
|
+
)
|
|
21
|
+
from transformers import PreTrainedTokenizerFast
|
|
22
|
+
from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
|
|
23
|
+
from transformers import AutoImageProcessor, ProcessorMixin
|
|
24
|
+
import torch
|
|
25
|
+
# Module-level logger from transformers' logging utilities.
logger = logging.get_logger(__name__)

# Third-party optional imports: albumentations and OpenCV are soft
# dependencies; each name is bound to ``None`` when its import fails.
|
|
27
|
+
|
|
28
|
+
try:
|
|
29
|
+
import albumentations as A
|
|
30
|
+
except Exception as _e:
|
|
31
|
+
A = None
|
|
32
|
+
_A_IMPORT_ERR = str(_e)
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
import cv2
|
|
36
|
+
except Exception:
|
|
37
|
+
cv2 = None
|
|
38
|
+
|
|
39
|
+
if is_vision_available():
|
|
40
|
+
from PIL import Image, ImageOps, ImageDraw
|
|
41
|
+
|
|
42
|
+
# Albumentations custom transforms.
#
# These classes subclass ``A.ImageOnlyTransform`` and are therefore only
# defined when albumentations imported successfully; otherwise every name is
# bound to ``None`` (see the ``else`` branch) so callers can check
# availability with ``is not None``.
if A is not None:

    def _as_scale_range(scale):
        """Return ``scale`` as a ``(low, high)`` pair for ``np.random.randint``.

        A tuple/list must have exactly two elements and is returned as-is.
        A scalar ``s`` becomes ``(s, s + 1)`` so the sampled size is always
        exactly ``s`` (``(s, s)`` made ``randint`` raise on every call).

        Raises:
            ValueError: if a tuple/list does not have exactly two elements.
        """
        if isinstance(scale, (tuple, list)):
            if len(scale) != 2:
                raise ValueError(
                    f'scale must have exactly 2 elements, got {scale!r}')
            return scale
        return (scale, scale + 1)

    def _random_ellipse_kernel(scale_range):
        """Elliptical structuring element with a random (w, h) size.

        NOTE: ``np.random.randint`` excludes ``high``, so e.g. ``(2, 3)``
        always yields a 2x2 kernel; kept as-is to match the existing recipe.
        """
        return cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE,
            tuple(np.random.randint(scale_range[0], scale_range[1], 2)))

    class Erosion(A.ImageOnlyTransform):
        """Morphological erosion with a randomly sized elliptical kernel.

        Args:
            scale: ``(low, high)`` kernel-size range (``high`` exclusive) or
                a single int for a fixed kernel size.
            always_apply: Forwarded to ``A.ImageOnlyTransform``.
            p: Probability of applying the transform.
        """

        def __init__(self, scale, always_apply=False, p=0.5):
            super().__init__(always_apply=always_apply, p=p)
            self.scale = _as_scale_range(scale)

        def apply(self, img, **params):
            # Without OpenCV the transform is a no-op.
            if cv2 is None:
                return img
            kernel = _random_ellipse_kernel(self.scale)
            return cv2.erode(img, kernel, iterations=1)

    class Dilation(A.ImageOnlyTransform):
        """Morphological dilation with a randomly sized elliptical kernel.

        Args:
            scale: ``(low, high)`` kernel-size range (``high`` exclusive) or
                a single int for a fixed kernel size.
            always_apply: Forwarded to ``A.ImageOnlyTransform``.
            p: Probability of applying the transform.
        """

        def __init__(self, scale, always_apply=False, p=0.5):
            super().__init__(always_apply=always_apply, p=p)
            self.scale = _as_scale_range(scale)

        def apply(self, img, **params):
            # Without OpenCV the transform is a no-op.
            if cv2 is None:
                return img
            kernel = _random_ellipse_kernel(self.scale)
            return cv2.dilate(img, kernel, iterations=1)

    class Bitmap(A.ImageOnlyTransform):
        """Set every pixel value below ``lower`` to ``value``.

        Args:
            value: Replacement value for pixels below ``lower``.
            lower: Threshold; values strictly below it are replaced.
            p: Probability of applying the transform.
        """

        def __init__(self, value=0, lower=200, p=0.5):
            super().__init__(p=p)
            self.lower = lower
            self.value = value

        def apply(self, img, **params):
            img = img.copy()
            img[img < self.lower] = self.value
            return img

    class Fog(A.ImageOnlyTransform):
        """Blend a plasma-fractal fog layer into the image.

        Args:
            mag: Index into three preset ``(strength, decay)`` pairs; an
                out-of-range value (default ``-1``) picks one at random.
            always_apply: Forwarded to ``A.ImageOnlyTransform``.
            p: Probability of applying the transform.
        """

        def __init__(self, mag=-1, always_apply=False, p=1.):
            super().__init__(always_apply=always_apply, p=p)
            self.rng = np.random.default_rng()
            self.mag = mag

        def apply(self, img, **params):
            img = Image.fromarray(img.astype(np.uint8))
            w, h = img.size
            c = [(1.5, 2), (2., 2), (2.5, 1.7)]
            if self.mag < 0 or self.mag >= len(c):
                index = self.rng.integers(0, len(c))
            else:
                index = self.mag
            c = c[index]
            n_channels = len(img.getbands())
            isgray = n_channels == 1
            img = np.asarray(img) / 255.
            max_val = img.max()
            # Power-of-two map (plasma_fractal requires it) covering at least
            # twice the longest image side.
            max_size = 2**math.ceil(math.log2(max(w, h)) + 1)
            fog = c[0] * plasma_fractal(mapsize=max_size,
                                        wibbledecay=c[1],
                                        rng=self.rng)[:h, :w][..., np.newaxis]
            if isgray:
                fog = np.squeeze(fog)
            else:
                # Match the image's channel count (3 for RGB).
                fog = np.repeat(fog, n_channels, axis=2)
            img += fog
            img = np.clip(img * max_val / (max_val + c[0]), 0, 1) * 255
            return img.astype(np.uint8)

    # Frost texture file names; see ``_load_frost_texture`` for lookup order.
    _FROST_TEXTURES = ('frost1.png', 'frost2.png', 'frost3.png',
                       'frost4.jpg', 'frost5.jpg', 'frost6.jpg')

    def _load_frost_texture(name):
        """Return frost texture ``name`` as an RGB PIL image, or ``None``.

        Looks next to this module first (``cmer_frost/``), then at the
        historical working-directory-relative path
        ``./openrec/preprocess/cmer_frost/``.
        """
        import os
        candidates = (
            os.path.join(os.path.dirname(os.path.abspath(__file__)),
                         'cmer_frost', name),
            './openrec/preprocess/cmer_frost/' + name,
        )
        for path in candidates:
            try:
                with Image.open(path) as texture:
                    return texture.convert('RGB')
            except Exception:
                # Best effort: a missing/unreadable texture disables Frost.
                continue
        return None

    class Frost(A.ImageOnlyTransform):
        """Overlay a randomly cropped frost texture on the image.

        If no texture can be loaded, the image is returned unchanged.

        Args:
            mag: Index into three preset ``(image_weight, frost_weight)``
                pairs; an out-of-range value picks one at random.
            always_apply: Forwarded to ``A.ImageOnlyTransform``.
            p: Probability of applying the transform.
        """

        def __init__(self, mag=-1, always_apply=False, p=1.):
            super().__init__(always_apply=always_apply, p=p)
            self.rng = np.random.default_rng()
            self.mag = mag

        def apply(self, img, **params):
            img = Image.fromarray(img.astype(np.uint8))
            w, h = img.size
            c = [(0.78, 0.22), (0.64, 0.36), (0.5, 0.5)]
            if self.mag < 0 or self.mag >= len(c):
                index = self.rng.integers(0, len(c))
            else:
                index = self.mag
            c = c[index]
            index = self.rng.integers(0, len(_FROST_TEXTURES))
            frost = _load_frost_texture(_FROST_TEXTURES[index])
            if frost is None:
                return np.asarray(img).astype(np.uint8)

            # Scale the texture to cover the image, keeping its aspect ratio.
            f_w, f_h = frost.size
            if w / h > f_w / f_h:
                f_h = round(f_h * w / f_w)
                f_w = w
            else:
                f_w = round(f_w * h / f_h)
                f_h = h
            frost = np.asarray(frost.resize((f_w, f_h)))
            y_start = self.rng.integers(0, f_h - h + 1)
            x_start = self.rng.integers(0, f_w - w + 1)
            frost = frost[y_start:y_start + h, x_start:x_start + w]

            isgray = len(img.getbands()) == 1
            img = np.asarray(img)
            if isgray:
                img = np.repeat(np.expand_dims(img, axis=2), 3, axis=2)
            img = np.clip(np.round(c[0] * img + c[1] * frost), 0, 255)
            img = img.astype(np.uint8)
            if isgray:
                # Back to a single channel so output matches the input layout
                # (np.squeeze cannot drop the size-3 channel axis).
                img = np.asarray(ImageOps.grayscale(Image.fromarray(img)))
            return img

    class Snow(A.ImageOnlyTransform):
        """Add motion-blurred snow flakes and a brightening wash.

        Args:
            mag: Index into three preset parameter tuples; an out-of-range
                value picks one at random.
            always_apply: Forwarded to ``A.ImageOnlyTransform``.
            p: Probability of applying the transform.
        """

        def __init__(self, mag=-1, always_apply=False, p=1.):
            super().__init__(always_apply=always_apply, p=p)
            self.rng = np.random.default_rng()
            self.mag = mag

        def apply(self, img, **params):
            img_pil = Image.fromarray(img.astype(np.uint8))
            w, h = img_pil.size
            # (loc, scale, _, threshold, blur radius, blur sigma, image weight)
            c = [(0.1, 0.3, 3, 0.5, 10, 4, 0.8),
                 (0.2, 0.3, 2, 0.5, 12, 4, 0.7),
                 (0.55, 0.3, 4, 0.9, 12, 8, 0.7)]
            if self.mag < 0 or self.mag >= len(c):
                index = self.rng.integers(0, len(c))
            else:
                index = self.mag
            c = c[index]
            isgray = (len(img_pil.getbands()) == 1)
            img = np.asarray(img_pil, dtype=np.float32) / 255.
            if isgray:
                img = np.repeat(img[..., None], 3, axis=2)
            snow_layer = self.rng.normal(loc=c[0],
                                         scale=c[1],
                                         size=img.shape[:2])
            snow_layer[snow_layer < c[3]] = 0
            snow_layer = np.clip(snow_layer, 0, 1).astype(np.float32)
            angle = float(self.rng.uniform(-135, -45))
            snow_layer = motion_blur(snow_layer,
                                     radius=c[4],
                                     sigma=c[5],
                                     angle=angle)
            snow_layer = snow_layer[..., None]
            img = c[6] * img
            if cv2 is not None:
                gray_img = (1 - c[6]) * np.maximum(
                    img,
                    cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).reshape(h, w, 1) *
                    1.5 + 0.5)
                img += gray_img
            img = np.clip(img + snow_layer + np.rot90(snow_layer, k=2), 0,
                          1) * 255
            img = img.astype(np.uint8)
            if isgray:
                # Back to a single channel so output matches the input layout.
                return np.asarray(ImageOps.grayscale(Image.fromarray(img)))
            return img

    class Rain(A.ImageOnlyTransform):
        """Draw short slanted light-grey streaks (rain) on the image.

        Images with ``min(w, h) <= 5`` are returned unchanged.

        Args:
            mag: Index into ``[50, 70, 90]`` (base streak count); an
                out-of-range value uses index 0.
            always_apply: Forwarded to ``A.ImageOnlyTransform``.
            p: Probability of applying the transform.
        """

        def __init__(self, mag=-1, always_apply=False, p=1.):
            super().__init__(always_apply=always_apply, p=p)
            self.rng = np.random.default_rng()
            self.mag = mag

        def apply(self, img, **params):
            img = Image.fromarray(img.astype(np.uint8))
            img = img.copy()
            w, h = img.size
            isgray = len(img.getbands()) == 1
            # Upper bound is exclusive, so this is always 1.
            line_width = self.rng.integers(1, 2)
            c = [50, 70, 90]
            index = self.mag if 0 <= self.mag < len(c) else 0
            c = c[index]
            n_rains = self.rng.integers(c, c + 20)
            slant = self.rng.integers(-60, 60)
            fillcolor = 200 if isgray else (200, 200, 200)
            max_length = min(w, h, 10)
            if max_length <= 5:
                # rng.integers(5, max_length) would raise on such small images.
                return np.asarray(img).astype(np.uint8)
            sin_slant = math.sin(slant * math.pi / 180.)
            cos_slant = math.cos(slant * math.pi / 180.)
            draw = ImageDraw.Draw(img)
            # range(1, n_rains) draws n_rains - 1 streaks (original behavior).
            for _ in range(1, n_rains):
                length = self.rng.integers(5, max_length)
                x1 = self.rng.integers(0, w - length)
                y1 = self.rng.integers(0, h - length)
                x2 = int(x1 + length * sin_slant)
                y2 = int(y1 + length * cos_slant)
                draw.line([(x1, y1), (x2, y2)],
                          width=line_width,
                          fill=fillcolor)
            return np.asarray(img).astype(np.uint8)

    class Shadow(A.ImageOnlyTransform):
        """Darken a random full-height quadrilateral with a translucent shadow.

        Images narrower than 2 px are returned unchanged.

        Args:
            mag: Index into ``[64, 96, 128]`` (base shadow alpha); an
                out-of-range value uses index 0.
            always_apply: Forwarded to ``A.ImageOnlyTransform``.
            p: Probability of applying the transform.
        """

        def __init__(self, mag=-1, always_apply=False, p=1.):
            super().__init__(always_apply=always_apply, p=p)
            self.rng = np.random.default_rng()
            self.mag = mag

        def apply(self, img, **params):
            img = Image.fromarray(img.astype(np.uint8))
            w, h = img.size
            if w < 2:
                # rng.integers(0, w // 2) would raise (empty range).
                return np.asarray(img).astype(np.uint8)
            isgray = len(img.getbands()) == 1
            c = [64, 96, 128]
            index = self.mag if 0 <= self.mag < len(c) else 0
            c = c[index]
            img = img.convert('RGBA')
            overlay = Image.new('RGBA', img.size, (255, 255, 255, 0))
            draw = ImageDraw.Draw(overlay)
            transparency = self.rng.integers(c, c + 32)
            # Top edge spans the image top, bottom edge the image bottom.
            x1 = self.rng.integers(0, w // 2)
            y1 = 0
            x2 = self.rng.integers(w // 2, w)
            y2 = 0
            x3 = self.rng.integers(w // 2, w)
            y3 = h - 1
            x4 = self.rng.integers(0, w // 2)
            y4 = h - 1
            draw.polygon([(x1, y1), (x2, y2), (x3, y3), (x4, y4)],
                         fill=(0, 0, 0, transparency))
            img = Image.alpha_composite(img, overlay)
            img = img.convert('RGB')
            if isgray:
                img = ImageOps.grayscale(img)
            return np.asarray(img).astype(np.uint8)

else:
    # Fallback placeholders when albumentations is missing.
    Erosion = None
    Dilation = None
    Bitmap = None
    Fog = None
    Frost = None
    Snow = None
    Rain = None
    Shadow = None
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
def clipped_zoom(img, zoom_factor):
    """Zoom into the centre of ``img`` and crop back to the original size.

    A centre crop of about ``1 / zoom_factor`` of each spatial dimension is
    upsampled with ``scipy.ndimage.zoom`` (linear, ``order=1``). The result
    is then trimmed to the input's height and width.

    Height and width are handled independently, so non-square images keep
    their shape. Any trailing (channel) axes are not zoomed, and 2-D inputs
    are accepted. For square ``(S, S, C)`` inputs the result is identical to
    the previous square-only implementation.

    Args:
        img: Array of shape ``(H, W)`` or ``(H, W, C)``.
        zoom_factor: Magnification factor, expected to be ``>= 1``.

    Returns:
        Array with the same height and width as ``img``.
    """
    h, w = img.shape[:2]
    crop_h = int(np.ceil(h / float(zoom_factor)))
    crop_w = int(np.ceil(w / float(zoom_factor)))
    top = (h - crop_h) // 2
    left = (w - crop_w) // 2
    # Zoom only the two spatial axes.
    factors = (zoom_factor, zoom_factor) + (1,) * (img.ndim - 2)
    zoomed = scizoom(img[top:top + crop_h, left:left + crop_w],
                     factors,
                     order=1)
    # Trim any extra pixels introduced by the ceil'd crop.
    trim_top = (zoomed.shape[0] - h) // 2
    trim_left = (zoomed.shape[1] - w) // 2
    return zoomed[trim_top:trim_top + h, trim_left:trim_left + w]
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def disk(radius, alias_blur=0.1, dtype=np.float32):
    """Build a normalized disk-shaped convolution kernel.

    The binary disk ``x**2 + y**2 <= radius**2`` is normalized to sum to 1.
    When OpenCV is available it is then softened with a small Gaussian blur
    (``alias_blur`` is the sigma). For ``radius <= 8`` a fixed 17x17 grid and
    a 3x3 blur are used; otherwise the grid spans ``[-radius, radius]`` and
    the blur is 5x5.

    Without OpenCV the un-blurred normalized disk is returned. The previous
    fallback returned a 1x1 zero kernel, which blanks any image it is
    convolved with.

    Args:
        radius: Disk radius in pixels.
        alias_blur: Gaussian sigma used for anti-aliasing.
        dtype: Kernel dtype.

    Returns:
        2-D kernel array.
    """
    if radius <= 8:
        coords = np.arange(-8, 8 + 1)
        ksize = (3, 3)
    else:
        coords = np.arange(-radius, radius + 1)
        ksize = (5, 5)
    x, y = np.meshgrid(coords, coords)
    aliased_disk = np.asarray((x**2 + y**2) <= radius**2, dtype=dtype)
    aliased_disk /= np.sum(aliased_disk)
    if cv2 is None:
        return aliased_disk
    return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def plasma_fractal(mapsize=256, wibbledecay=3, rng=None):
    """Generate a square plasma-fractal map with the diamond-square method.

    Args:
        mapsize: Side length of the map; must be a power of two.
        wibbledecay: Divisor applied to the noise amplitude after each
            refinement level (larger values give smoother maps).
        rng: ``np.random.Generator``; ``np.random.default_rng()`` if ``None``.

    Returns:
        ``float64`` array of shape ``(mapsize, mapsize)`` rescaled so its
        minimum is 0 and its maximum is 1.
    """
    assert (mapsize & (mapsize - 1) == 0)
    if rng is None:
        rng = np.random.default_rng()
    grid = np.empty((mapsize, mapsize), dtype=np.float64)
    grid[0, 0] = 0
    amplitude = 100
    step = mapsize

    def _perturbed_mean(total):
        # ``total`` is a sum of four neighbours; average it and add noise
        # scaled by the current amplitude (read from the enclosing scope).
        return total / 4 + amplitude * rng.uniform(-amplitude, amplitude,
                                                   total.shape)

    while step >= 2:
        half = step // 2

        # Square step: centre of each square from its four corners.
        corners = grid[0:mapsize:step, 0:mapsize:step]
        square_sum = corners + np.roll(corners, shift=-1, axis=0)
        square_sum += np.roll(square_sum, shift=-1, axis=1)
        grid[half:mapsize:step, half:mapsize:step] = _perturbed_mean(square_sum)

        # Diamond step: edge midpoints from adjacent centres and corners.
        centres = grid[half:mapsize:step, half:mapsize:step]
        corners = grid[0:mapsize:step, 0:mapsize:step]
        horizontal = ((centres + np.roll(centres, 1, axis=0)) +
                      (corners + np.roll(corners, -1, axis=1)))
        grid[0:mapsize:step, half:mapsize:step] = _perturbed_mean(horizontal)
        vertical = ((centres + np.roll(centres, 1, axis=1)) +
                    (corners + np.roll(corners, -1, axis=0)))
        grid[half:mapsize:step, 0:mapsize:step] = _perturbed_mean(vertical)

        step //= 2
        amplitude /= wibbledecay

    grid -= grid.min()
    return grid / grid.max()
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def motion_blur(img: np.ndarray, radius: int, sigma: float,
                angle: float) -> np.ndarray:
    """Apply a directional (linear) motion blur to ``img``.

    A one-pixel-wide horizontal line spanning a ``(2*radius+1)``-square
    kernel is rotated by ``angle`` degrees about the kernel's centre pixel.
    It is optionally Gaussian-smoothed with ``sigma``, normalized to sum 1,
    and convolved with ``img`` using replicated borders. Without OpenCV the
    input is returned unchanged.

    Args:
        img: Image array accepted by ``cv2.filter2D``.
        radius: Half-length of the blur streak in pixels.
        sigma: Gaussian sigma applied to the kernel; ``<= 0`` disables it.
        angle: Rotation of the streak in degrees.

    Returns:
        Blurred array with the same shape as ``img``.
    """
    if cv2 is None:
        return img
    kernel_size = max(1, int(radius) * 2 + 1)
    psf = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    psf[kernel_size // 2] = 1.0
    # The centre pixel of an odd-sized kernel is at index (k - 1) / 2.
    # Rotating about k / 2 (as before) pivots half a pixel off the line,
    # shifting and clipping the rotated streak.
    centre = (kernel_size - 1) / 2.0
    rotation = cv2.getRotationMatrix2D((centre, centre), angle, 1)
    psf = cv2.warpAffine(psf, rotation, (kernel_size, kernel_size))
    if sigma > 0:
        psf = cv2.GaussianBlur(psf, (kernel_size, kernel_size), sigma)
    total = psf.sum()
    if total != 0:
        psf /= total
    return cv2.filter2D(img, -1, psf, borderType=cv2.BORDER_REPLICATE)
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
class CMERImageProcessor(BaseImageProcessor):
|
|
402
|
+
model_input_names = [
|
|
403
|
+
'pixel_values', 'orig_spatial_shape', 'expanded_from_indices',
|
|
404
|
+
'is_original_flags'
|
|
405
|
+
]
|
|
406
|
+
|
|
407
|
+
def __init__(
    self,
    down_sample_ratio: int = 32,
    do_convert_rgb: bool = True,
    do_rescale: bool = True,
    rescale_factor: float = 1.0 / 255.0,
    do_normalize: bool = True,
    image_mean: Optional[Union[float, list[float]]] = None,
    image_std: Optional[Union[float, list[float]]] = None,
    resample: 'PILImageResampling' = PILImageResampling.BILINEAR,
    output_channel_format: ChannelDimension = ChannelDimension.FIRST,
    pad_value_strategy: str = 'mean',
    pad_value: Optional[Union[float, List[float]]] = None,
    center_pad: bool = False,
    do_augment: bool = True,
    augment_prob: float = 1.0,
    pre_pad_expand_ratio: float = 0.04,
    pre_pad_min_px: int = 8,
    aug_repeats: int = 0,
    keep_original: bool = True,
    num_workers: int = 8,
    pad_num_workers: Optional[int] = None,
    resize_backend: str = 'auto',
    normalize_inplace: bool = True,
    **kwargs,
):
    """Store the processor configuration and build the augmentation pipeline.

    Args:
        down_sample_ratio: ``preprocess`` rounds target heights/widths to a
            multiple of this value.
        do_convert_rgb: Convert inputs to RGB before processing.
        do_rescale, rescale_factor, do_normalize, image_mean, image_std:
            Rescale/normalize settings, validated per call in ``preprocess``.
            ``image_mean`` and ``image_std`` default to ``[0.5, 0.5, 0.5]``.
        resample: Resampling filter for the non-OpenCV resize path.
        output_channel_format, pad_value_strategy, pad_value, center_pad,
        normalize_inplace: Output layout / padding / normalization options,
            stored here and resolved per call by ``preprocess``.
        do_augment: Stored as ``default_do_augment``. ``preprocess`` itself
            defaults ``do_augment=True``, so this value only applies when
            ``None`` is passed explicitly.
        augment_prob: Probability given to the albumentations ``Compose``.
        pre_pad_expand_ratio, pre_pad_min_px: Size of the white border added
            before augmentation.
        aug_repeats: Augmented copies generated per input (clamped to >= 0).
        keep_original: Also emit the un-augmented image.
        num_workers, pad_num_workers: Worker counts, clamped to >= 0.
            ``pad_num_workers`` defaults to ``num_workers``.
        resize_backend: ``'auto'`` uses OpenCV when it is importable.
        **kwargs: Forwarded to ``BaseImageProcessor.__init__``.
    """
    super().__init__(**kwargs)
    self.down_sample_ratio = int(down_sample_ratio)
    self.do_convert_rgb = bool(do_convert_rgb)
    self.do_rescale = bool(do_rescale)
    self.rescale_factor = float(rescale_factor)
    self.do_normalize = bool(do_normalize)
    self.image_mean = image_mean if image_mean is not None else [
        0.5, 0.5, 0.5
    ]
    self.image_std = image_std if image_std is not None else [
        0.5, 0.5, 0.5
    ]
    self.resample = resample
    self.output_channel_format = output_channel_format
    self.pad_value_strategy = str(pad_value_strategy).lower()
    self.pad_value = pad_value
    self.center_pad = bool(center_pad)
    self.default_do_augment = bool(do_augment)
    self.augment_prob = float(augment_prob)
    self.pre_pad_expand_ratio = float(pre_pad_expand_ratio)
    self.pre_pad_min_px = int(pre_pad_min_px)
    self.aug_repeats = max(int(aug_repeats), 0)
    self.keep_original = bool(keep_original)
    self.num_workers = max(int(num_workers), 0)
    # Normalize like num_workers (and like preprocess's per-call override).
    self.pad_num_workers = (self.num_workers if pad_num_workers is None else
                            max(int(pad_num_workers), 0))
    self.resize_backend = resize_backend
    self.normalize_inplace = bool(normalize_inplace)
    self._augmentations = self._build_augmentations()
|
|
461
|
+
|
|
462
|
+
def _build_augmentations(self):
    """Build the albumentations pipeline used for image augmentation.

    Returns:
        An ``A.Compose`` applied with probability ``self.augment_prob``, or
        ``None`` (after a one-time warning) if albumentations is missing.

    Custom transforms that fail to construct (e.g. after an albumentations
    API change) are skipped with a logged warning rather than dropped
    silently, so the remaining augmentations still run.
    """
    if A is None:
        logger.warning_once(
            f"[CMERImageProcessor] Albumentations 未安装,跳过图像增强。{globals().get('_A_IMPORT_ERR', '')}"
        )
        return None

    def _try_build(factory, *args):
        # Best effort: one broken optional transform must not disable all.
        try:
            return factory(*args)
        except Exception as err:
            logger.warning('[CMERImageProcessor] Skipping augmentation %s: %s',
                           getattr(factory, '__name__', factory), err)
            return None

    tlist = []
    if Bitmap is not None:
        tlist.append(Bitmap(p=0.2))

    weather_ops = []
    for op_cls in (Fog, Frost, Snow, Rain, Shadow):
        if op_cls is not None:
            op = _try_build(op_cls)
            if op is not None:
                weather_ops.append(op)
    if weather_ops:
        tlist.append(A.OneOf(weather_ops, p=0.5))

    morph_ops = []
    for op_cls in (Erosion, Dilation):
        if op_cls is not None:
            op = _try_build(op_cls, (2, 3))
            if op is not None:
                morph_ops.append(op)
    if morph_ops:
        tlist.append(A.OneOf(morph_ops, p=0.2))

    tlist.extend([
        A.ShiftScaleRotate(shift_limit=0,
                           scale_limit=(-.15, 0),
                           rotate_limit=1,
                           border_mode=0,
                           interpolation=3,
                           value=[255, 255, 255],
                           p=1),
        A.GridDistortion(distort_limit=0.1,
                         border_mode=0,
                         interpolation=3,
                         value=[255, 255, 255],
                         p=0.5),
        A.RGBShift(r_shift_limit=15,
                   g_shift_limit=15,
                   b_shift_limit=15,
                   p=0.3),
        A.GaussNoise(var_limit=(10.0, 20.0), p=0.2),
        A.RandomBrightnessContrast(0.05, (-0.2, 0), True, p=0.2),
    ])
    return A.Compose(tlist, p=self.augment_prob)
|
|
514
|
+
|
|
515
|
+
@staticmethod
def _constant_border(img: np.ndarray,
                     pad_px: int,
                     value: int = 255) -> np.ndarray:
    """Pad ``img`` on all four sides with a constant value.

    Args:
        img: ``(H, W)`` or ``(H, W, C)`` array.
        pad_px: Border width in pixels; ``<= 0`` returns ``img`` unchanged.
        value: Fill value (255 is white for uint8 images).

    Returns:
        The padded array (OpenCV when available, otherwise ``np.pad``).
    """
    if pad_px <= 0:
        return img
    if cv2 is not None:
        return cv2.copyMakeBorder(img,
                                  pad_px,
                                  pad_px,
                                  pad_px,
                                  pad_px,
                                  cv2.BORDER_CONSTANT,
                                  value=[value, value, value])
    # Pad only the two spatial axes, so 2-D (grayscale) inputs work too.
    pad_width = ((pad_px, pad_px), (pad_px, pad_px)) + ((0, 0),) * (img.ndim - 2)
    return np.pad(img, pad_width, constant_values=value)
|
|
531
|
+
|
|
532
|
+
def _maybe_augment_uint8(self, img_uint8: np.ndarray, seed: Optional[int],
|
|
533
|
+
pre_pad_px: int) -> np.ndarray:
|
|
534
|
+
if self._augmentations is None:
|
|
535
|
+
return img_uint8
|
|
536
|
+
if pre_pad_px > 0:
|
|
537
|
+
img_uint8 = self._constant_border(img_uint8, pre_pad_px, value=255)
|
|
538
|
+
if seed is not None:
|
|
539
|
+
rng_state = np.random.get_state()
|
|
540
|
+
np.random.seed(seed)
|
|
541
|
+
try:
|
|
542
|
+
out = self._augmentations(image=img_uint8)['image']
|
|
543
|
+
finally:
|
|
544
|
+
np.random.set_state(rng_state)
|
|
545
|
+
return out
|
|
546
|
+
else:
|
|
547
|
+
return self._augmentations(image=img_uint8)['image']
|
|
548
|
+
|
|
549
|
+
def _prep_uint8(self, img, input_data_format) -> np.ndarray:
    """Convert one input image into a channels-last ``uint8`` numpy array.

    Args:
        img: A single image (PIL, numpy or tensor).
        input_data_format: Channel layout of ``img``; inferred when ``None``.

    Returns:
        HWC ``uint8`` array. Floating-point inputs are treated as ``[0, 1]``
        data (clipped, scaled by 255 and rounded). Other non-uint8 dtypes
        are cast directly.
    """
    source = convert_to_rgb(img) if self.do_convert_rgb else img
    array = to_numpy_array(source)
    layout = (infer_channel_dimension_format(array)
              if input_data_format is None else input_data_format)
    if layout == ChannelDimension.FIRST:
        array = np.transpose(array, (1, 2, 0))
    elif layout != ChannelDimension.LAST:
        array = to_channel_dimension_format(array,
                                            ChannelDimension.LAST,
                                            input_channel_dim=layout)
    if array.dtype == np.uint8:
        return array
    if array.dtype.kind == 'f':
        return (np.clip(array, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    return array.astype(np.uint8)
|
|
572
|
+
|
|
573
|
+
def _resize_uint8(self, img_uint8: np.ndarray, th: int, tw: int,
                  backend: str) -> np.ndarray:
    """Resize an HWC uint8 image to height ``th`` and width ``tw``.

    With ``backend == 'cv2'`` and OpenCV importable, bilinear
    ``cv2.resize`` is used, and ``self.resample`` is ignored on that path.
    Otherwise transformers' ``resize`` is used with ``self.resample``.
    """
    use_opencv = backend == 'cv2' and cv2 is not None
    if not use_opencv:
        return resize(img_uint8,
                      size=(th, tw),
                      resample=self.resample,
                      input_data_format=ChannelDimension.LAST)
    # cv2 expects (width, height); INTER_LINEAR == 1.
    return cv2.resize(img_uint8, (tw, th), interpolation=cv2.INTER_LINEAR)
|
|
581
|
+
|
|
582
|
+
def preprocess_auto(self,
                    images: ImageInput,
                    return_tensors: Optional[Union[str, TensorType]] = None,
                    trainer=None,
                    **kwargs) -> BatchFeature:
    """Call ``preprocess``, deriving the augmentation default from ``trainer``.

    If ``trainer`` is given and the caller did not pass ``do_augment``,
    augmentation is enabled only while ``trainer.is_in_train`` is truthy.
    A missing attribute counts as not training. Previously only the
    training case was handled, and since ``preprocess`` already defaults
    ``do_augment=True``, evaluation with a trainer still augmented. Without
    a trainer, ``preprocess``'s own default applies.

    Args:
        images: Images forwarded to ``preprocess``.
        return_tensors: Forwarded to ``preprocess``.
        trainer: Optional object exposing ``is_in_train``.
        **kwargs: Extra ``preprocess`` options; an explicit ``do_augment``
            always wins.

    Returns:
        Whatever ``preprocess`` returns.
    """
    if trainer is not None:
        kwargs.setdefault('do_augment',
                          bool(getattr(trainer, 'is_in_train', False)))
    return self.preprocess(images, return_tensors=return_tensors, **kwargs)
|
|
591
|
+
|
|
592
|
+
@filter_out_non_signature_kwargs()
|
|
593
|
+
def preprocess(
|
|
594
|
+
self,
|
|
595
|
+
images: ImageInput,
|
|
596
|
+
return_tensors: Optional[Union[str, TensorType]] = None,
|
|
597
|
+
input_data_format: Optional[Union[str, ChannelDimension]] = None,
|
|
598
|
+
do_convert_rgb: Optional[bool] = None,
|
|
599
|
+
do_rescale: Optional[bool] = None,
|
|
600
|
+
rescale_factor: Optional[float] = None,
|
|
601
|
+
do_normalize: Optional[bool] = None,
|
|
602
|
+
image_mean: Optional[Union[float, list[float]]] = None,
|
|
603
|
+
image_std: Optional[Union[float, list[float]]] = None,
|
|
604
|
+
down_sample_ratio: Optional[int] = None,
|
|
605
|
+
resample: Optional['PILImageResampling'] = None,
|
|
606
|
+
output_channel_format: Optional[ChannelDimension] = None,
|
|
607
|
+
pad_value_strategy: Optional[str] = None,
|
|
608
|
+
pad_value: Optional[Union[float, List[float]]] = None,
|
|
609
|
+
center_pad: Optional[bool] = None,
|
|
610
|
+
do_augment: Optional[bool] = True,
|
|
611
|
+
augment_seed: Optional[int] = None,
|
|
612
|
+
pre_pad_expand_ratio: Optional[float] = None,
|
|
613
|
+
pre_pad_min_px: Optional[int] = None,
|
|
614
|
+
aug_repeats: Optional[int] = None,
|
|
615
|
+
keep_original: Optional[bool] = None,
|
|
616
|
+
num_workers: Optional[int] = None,
|
|
617
|
+
pad_num_workers: Optional[int] = None,
|
|
618
|
+
resize_backend: Optional[str] = None,
|
|
619
|
+
normalize_inplace: Optional[bool] = None,
|
|
620
|
+
) -> BatchFeature:
|
|
621
|
+
do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb
|
|
622
|
+
do_rescale = self.do_rescale if do_rescale is None else do_rescale
|
|
623
|
+
rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
|
|
624
|
+
do_normalize = self.do_normalize if do_normalize is None else do_normalize
|
|
625
|
+
image_mean = self.image_mean if image_mean is None else image_mean
|
|
626
|
+
image_std = self.image_std if image_std is None else image_std
|
|
627
|
+
down_sample_ratio = self.down_sample_ratio if down_sample_ratio is None else int(
|
|
628
|
+
down_sample_ratio)
|
|
629
|
+
resample = self.resample if resample is None else resample
|
|
630
|
+
output_channel_format = self.output_channel_format if output_channel_format is None else output_channel_format
|
|
631
|
+
pad_value_strategy = self.pad_value_strategy if pad_value_strategy is None else pad_value_strategy.lower(
|
|
632
|
+
)
|
|
633
|
+
pad_value = self.pad_value if pad_value is None else pad_value
|
|
634
|
+
center_pad = self.center_pad if center_pad is None else bool(
|
|
635
|
+
center_pad)
|
|
636
|
+
do_augment = self.default_do_augment if do_augment is None else bool(
|
|
637
|
+
do_augment)
|
|
638
|
+
pre_pad_expand_ratio = self.pre_pad_expand_ratio if pre_pad_expand_ratio is None else float(
|
|
639
|
+
pre_pad_expand_ratio)
|
|
640
|
+
pre_pad_min_px = self.pre_pad_min_px if pre_pad_min_px is None else int(
|
|
641
|
+
pre_pad_min_px)
|
|
642
|
+
aug_repeats = self.aug_repeats if aug_repeats is None else max(
|
|
643
|
+
int(aug_repeats), 0)
|
|
644
|
+
keep_original = self.keep_original if keep_original is None else bool(
|
|
645
|
+
keep_original)
|
|
646
|
+
num_workers = self.num_workers if num_workers is None else max(
|
|
647
|
+
int(num_workers), 0)
|
|
648
|
+
pad_num_workers = self.pad_num_workers if pad_num_workers is None else max(
|
|
649
|
+
int(pad_num_workers), 0)
|
|
650
|
+
resize_backend = (self.resize_backend if resize_backend is None else
|
|
651
|
+
resize_backend).lower()
|
|
652
|
+
normalize_inplace = self.normalize_inplace if normalize_inplace is None else bool(
|
|
653
|
+
normalize_inplace)
|
|
654
|
+
if type(images) is dict:
|
|
655
|
+
images = images.get('image', None)
|
|
656
|
+
images = self.fetch_images(images)
|
|
657
|
+
else:
|
|
658
|
+
images = self.fetch_images(images)
|
|
659
|
+
images = make_flat_list_of_images(images)
|
|
660
|
+
if not valid_images(images):
|
|
661
|
+
raise ValueError(
|
|
662
|
+
'Invalid image type. Must be PIL.Image.Image, numpy.ndarray, or torch.Tensor'
|
|
663
|
+
)
|
|
664
|
+
validate_preprocess_arguments(
|
|
665
|
+
do_rescale=do_rescale,
|
|
666
|
+
rescale_factor=rescale_factor,
|
|
667
|
+
do_normalize=do_normalize,
|
|
668
|
+
image_mean=image_mean,
|
|
669
|
+
image_std=image_std,
|
|
670
|
+
)
|
|
671
|
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
672
|
+
|
|
673
|
+
def _process_one(idx_img: int):
|
|
674
|
+
base = self._prep_uint8(images[idx_img], input_data_format)
|
|
675
|
+
h0, w0 = base.shape[:2]
|
|
676
|
+
results_imgs: List[np.ndarray] = []
|
|
677
|
+
results_sizes: List[Tuple[int, int]] = []
|
|
678
|
+
results_from: List[int] = []
|
|
679
|
+
results_flag: List[bool] = []
|
|
680
|
+
cand: List[Tuple[np.ndarray, bool]] = []
|
|
681
|
+
if keep_original:
|
|
682
|
+
cand.append((base, True))
|
|
683
|
+
if do_augment and self._augmentations is not None and aug_repeats > 0:
|
|
684
|
+
est_pad = max(
|
|
685
|
+
int(max(h0, w0) * pre_pad_expand_ratio),
|
|
686
|
+
pre_pad_min_px if pre_pad_expand_ratio > 0 else 0)
|
|
687
|
+
for k in range(aug_repeats):
|
|
688
|
+
seed_k = None if augment_seed is None else (
|
|
689
|
+
int(augment_seed) + idx_img * (aug_repeats + 1) + k)
|
|
690
|
+
aug_img = self._maybe_augment_uint8(base,
|
|
691
|
+
seed=seed_k,
|
|
692
|
+
pre_pad_px=est_pad)
|
|
693
|
+
cand.append((aug_img, False))
|
|
694
|
+
|
|
695
|
+
is_cv2_avail = (resize_backend == 'auto' and cv2 is not None)
|
|
696
|
+
be = 'cv2' if is_cv2_avail else resize_backend
|
|
697
|
+
max_long_edge = 1024
|
|
698
|
+
|
|
699
|
+
for uint8_img, is_orig in cand:
|
|
700
|
+
hh, ww = uint8_img.shape[:2]
|
|
701
|
+
if max(hh, ww) > max_long_edge:
|
|
702
|
+
scale = float(max_long_edge) / float(max(hh, ww))
|
|
703
|
+
targ_h = max(1, int(math.floor(hh * scale)))
|
|
704
|
+
targ_w = max(1, int(math.floor(ww * scale)))
|
|
705
|
+
uint8_img = self._resize_uint8(uint8_img, targ_h, targ_w,
|
|
706
|
+
be)
|
|
707
|
+
hh, ww = uint8_img.shape[:2]
|
|
708
|
+
MIN_HW = 224
|
|
709
|
+
ds = down_sample_ratio
|
|
710
|
+
ceil_h = max(MIN_HW, math.ceil(hh / ds) * ds)
|
|
711
|
+
ceil_w = max(MIN_HW, math.ceil(ww / ds) * ds)
|
|
712
|
+
if max(ceil_h, ceil_w) <= max_long_edge:
|
|
713
|
+
th, tw = ceil_h, ceil_w
|
|
714
|
+
else:
|
|
715
|
+
floor_h = max(MIN_HW, (hh // ds) * ds)
|
|
716
|
+
floor_w = max(MIN_HW, (ww // ds) * ds)
|
|
717
|
+
if floor_h <= 0 or floor_w <= 0:
|
|
718
|
+
floor_h = max(MIN_HW,
|
|
719
|
+
min(hh, max_long_edge) // ds * ds)
|
|
720
|
+
floor_w = max(MIN_HW,
|
|
721
|
+
min(ww, max_long_edge) // ds * ds)
|
|
722
|
+
th, tw = floor_h, floor_w
|
|
723
|
+
|
|
724
|
+
rs_img = self._resize_uint8(uint8_img, th, tw, be)
|
|
725
|
+
if do_rescale:
|
|
726
|
+
rs_img = rs_img.astype(np.float32)
|
|
727
|
+
np.multiply(rs_img,
|
|
728
|
+
float(rescale_factor),
|
|
729
|
+
out=rs_img,
|
|
730
|
+
casting='unsafe')
|
|
731
|
+
else:
|
|
732
|
+
rs_img = rs_img.astype(np.float32)
|
|
733
|
+
results_imgs.append(rs_img)
|
|
734
|
+
results_sizes.append((th, tw))
|
|
735
|
+
results_from.append(idx_img)
|
|
736
|
+
results_flag.append(is_orig)
|
|
737
|
+
return results_imgs, results_sizes, results_from, results_flag
|
|
738
|
+
|
|
739
|
+
proc_list: List[np.ndarray] = []
|
|
740
|
+
rec_sizes: List[Tuple[int, int]] = []
|
|
741
|
+
from_indices: List[int] = []
|
|
742
|
+
is_orig_flags: List[bool] = []
|
|
743
|
+
if num_workers and num_workers > 1 and len(images) > 1:
|
|
744
|
+
with ThreadPoolExecutor(max_workers=num_workers) as ex:
|
|
745
|
+
futs = [ex.submit(_process_one, i) for i in range(len(images))]
|
|
746
|
+
for fu in as_completed(futs):
|
|
747
|
+
imgs_i, sizes_i, from_i, flag_i = fu.result()
|
|
748
|
+
proc_list.extend(imgs_i)
|
|
749
|
+
rec_sizes.extend(sizes_i)
|
|
750
|
+
from_indices.extend(from_i)
|
|
751
|
+
is_orig_flags.extend(flag_i)
|
|
752
|
+
else:
|
|
753
|
+
for i in range(len(images)):
|
|
754
|
+
imgs_i, sizes_i, from_i, flag_i = _process_one(i)
|
|
755
|
+
proc_list.extend(imgs_i)
|
|
756
|
+
rec_sizes.extend(sizes_i)
|
|
757
|
+
from_indices.extend(from_i)
|
|
758
|
+
is_orig_flags.extend(flag_i)
|
|
759
|
+
if len(proc_list) == 0:
|
|
760
|
+
return BatchFeature(data={
|
|
761
|
+
'image': [],
|
|
762
|
+
'orig_spatial_shape': [],
|
|
763
|
+
'expanded_from_indices': [],
|
|
764
|
+
'is_original_flags': []
|
|
765
|
+
},
|
|
766
|
+
tensor_type=return_tensors)
|
|
767
|
+
max_h = max(h for h, _ in rec_sizes)
|
|
768
|
+
max_w = max(w for _, w in rec_sizes)
|
|
769
|
+
mean = np.array(image_mean, dtype=np.float32)
|
|
770
|
+
std = np.array(image_std, dtype=np.float32)
|
|
771
|
+
inv_std = 1.0 / np.where(std == 0, 1.0, std)
|
|
772
|
+
|
|
773
|
+
def _maybe_scale_stats_to_image_domain(
|
|
774
|
+
_arr: np.ndarray, exemplar: np.ndarray) -> np.ndarray:
|
|
775
|
+
if not do_rescale and exemplar.max() > 1.5 and _arr.max() <= 1.5:
|
|
776
|
+
return _arr * 255.0
|
|
777
|
+
return _arr
|
|
778
|
+
|
|
779
|
+
def _make_pad_color(c: int, exemplar: np.ndarray) -> np.ndarray:
|
|
780
|
+
_mean = _maybe_scale_stats_to_image_domain(mean, exemplar)
|
|
781
|
+
if pad_value_strategy == 'mean':
|
|
782
|
+
col = _mean
|
|
783
|
+
elif pad_value_strategy == 'white':
|
|
784
|
+
col = np.ones(
|
|
785
|
+
(c, ), dtype=np.float32) * (1.0 if do_rescale else 255.0)
|
|
786
|
+
elif pad_value_strategy == 'zero':
|
|
787
|
+
col = np.zeros((c, ), dtype=np.float32)
|
|
788
|
+
elif pad_value_strategy == 'custom':
|
|
789
|
+
if pad_value is None:
|
|
790
|
+
col = _mean
|
|
791
|
+
else:
|
|
792
|
+
col = np.array(pad_value, dtype=np.float32)
|
|
793
|
+
if col.ndim == 0:
|
|
794
|
+
col = np.full((c, ), float(col), dtype=np.float32)
|
|
795
|
+
if col.shape[0] != c:
|
|
796
|
+
raise ValueError(
|
|
797
|
+
f'pad_value length must match channels={c}')
|
|
798
|
+
else:
|
|
799
|
+
col = _mean
|
|
800
|
+
return col
|
|
801
|
+
|
|
802
|
+
def _to_ch_first(arr: np.ndarray) -> np.ndarray:
|
|
803
|
+
return np.transpose(arr, (2, 0, 1))
|
|
804
|
+
|
|
805
|
+
batched: List[np.ndarray] = [None] * len(proc_list)
|
|
806
|
+
|
|
807
|
+
def _pad_one(i: int):
|
|
808
|
+
np_img = proc_list[i]
|
|
809
|
+
h, w = rec_sizes[i]
|
|
810
|
+
C = np_img.shape[2]
|
|
811
|
+
pad_color = _make_pad_color(C, np_img)
|
|
812
|
+
if center_pad:
|
|
813
|
+
y0 = (max_h - h) // 2
|
|
814
|
+
x0 = (max_w - w) // 2
|
|
815
|
+
else:
|
|
816
|
+
y0 = 0
|
|
817
|
+
x0 = 0
|
|
818
|
+
pad_img = np.empty((max_h, max_w, C), dtype=np.float32)
|
|
819
|
+
pad_img[...] = pad_color
|
|
820
|
+
pad_img[y0:y0 + h, x0:x0 + w, :] = np_img
|
|
821
|
+
if do_normalize:
|
|
822
|
+
_mean = _maybe_scale_stats_to_image_domain(mean, pad_img)
|
|
823
|
+
_invstd = _maybe_scale_stats_to_image_domain(inv_std, pad_img)
|
|
824
|
+
if normalize_inplace:
|
|
825
|
+
np.subtract(pad_img, _mean, out=pad_img)
|
|
826
|
+
np.multiply(pad_img, _invstd, out=pad_img)
|
|
827
|
+
else:
|
|
828
|
+
pad_img = (pad_img - _mean) * _invstd
|
|
829
|
+
batched[i] = _to_ch_first(
|
|
830
|
+
pad_img
|
|
831
|
+
) if output_channel_format == ChannelDimension.FIRST else pad_img
|
|
832
|
+
|
|
833
|
+
if pad_num_workers and pad_num_workers > 1 and len(proc_list) > 1:
|
|
834
|
+
with ThreadPoolExecutor(max_workers=pad_num_workers) as ex:
|
|
835
|
+
list(ex.map(_pad_one, range(len(proc_list))))
|
|
836
|
+
else:
|
|
837
|
+
for i in range(len(proc_list)):
|
|
838
|
+
_pad_one(i)
|
|
839
|
+
return BatchFeature(
|
|
840
|
+
data={
|
|
841
|
+
'image': batched,
|
|
842
|
+
'orig_spatial_shape': rec_sizes,
|
|
843
|
+
'expanded_from_indices': from_indices,
|
|
844
|
+
'is_original_flags': is_orig_flags,
|
|
845
|
+
},
|
|
846
|
+
tensor_type=return_tensors,
|
|
847
|
+
)
|
|
848
|
+
|
|
849
|
+
|
|
850
|
+
# Expose CMERImageProcessor through AutoImageProcessor under the 'CMER' key,
# registered as the slow (Python-side) image processor class.
AutoImageProcessor.register(
    'CMER',
    slow_image_processor_class=CMERImageProcessor,
)
|
|
852
|
+
class CMERProcessor(ProcessorMixin):
|
|
853
|
+
attributes = ['image_processor', 'tokenizer']
|
|
854
|
+
image_processor_class = 'CMERImageProcessor'
|
|
855
|
+
tokenizer_class = 'PreTrainedTokenizerFast'
|
|
856
|
+
|
|
857
|
+
def __init__(
    self,
    image_processor=None,
    tokenizer=None,
    tokenizer_file: str = './configs/rec/cmer/cmer_tokenizer/tokenizer.json',
    **kwargs
):
    """Build a CMER processor from an image processor and a tokenizer.

    Args:
        image_processor: Image processor to use. When ``None``, a
            ``CMERImageProcessor`` is constructed from ``**kwargs``.
        tokenizer: Tokenizer to use. When ``None``, a
            ``PreTrainedTokenizerFast`` is loaded from ``tokenizer_file``
            with right-side padding/truncation and the ``<|pad|>``,
            ``<|bos|>``, ``<|eos|>`` and ``<|unk|>`` special tokens.
        tokenizer_file: Path of the ``tokenizer.json`` used for the default
            tokenizer. The default value is a relative path, so it is
            resolved against the current working directory.
        **kwargs: Forwarded to ``CMERImageProcessor`` when
            ``image_processor`` is ``None``; otherwise unused.

    If loading the default tokenizer fails, a warning is logged and
    ``tokenizer=None`` is passed on to ``ProcessorMixin``.
    """
    if image_processor is None:
        image_processor = CMERImageProcessor(**kwargs)

    if tokenizer is None:
        try:
            tokenizer = PreTrainedTokenizerFast(
                tokenizer_file=tokenizer_file,
                padding_side="right",
                truncation_side="right",
                pad_token="<|pad|>",
                bos_token="<|bos|>",
                eos_token="<|eos|>",
                unk_token="<|unk|>",
            )
        except Exception as e:
            # Deliberate best-effort fallback: keep constructing without a
            # tokenizer, but report through logging instead of stdout.
            import logging
            logging.getLogger(__name__).warning(
                "Failed to initialize default tokenizer from %s. Error: %s",
                tokenizer_file, e)
            tokenizer = None

    super().__init__(image_processor=image_processor, tokenizer=tokenizer)
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
def __call__(
    self,
    images: ImageInput,
    text: Union[str, List[str]] = None,
    ids=None,
    categorys=None,
    return_tensors: Optional[Union[str, TensorType]] = "pt",
    padding: Union[bool, str] = True,
    truncation: bool = True,
    max_length: Optional[int] = None,
    **img_kwargs,
):
    """Preprocess images and tokenize their target texts for CMER.

    The image processor may expand each input image into several processed
    images (the original and/or augmented copies) and reports the source of
    every processed image in ``expanded_from_indices``. The per-image texts
    are replicated along that mapping, so every processed image receives a
    matching target sequence.

    Args:
        images: A single image, a list of images, a dict with an ``'image'``
            key, or a list of such dicts.
        text: Target string(s), one per input image. ``None`` (inference)
            yields an empty target for every input image.
        ids: Optional per-image identifiers; only their count is validated.
        categorys: Optional per-image categories; only their count is
            validated.
        return_tensors: Tensor type forwarded to both the image processor and
            the tokenizer (e.g. ``"pt"``, ``"np"`` or ``None``).
        padding, truncation, max_length: Forwarded to the tokenizer.
        **img_kwargs: Extra keyword arguments for
            ``image_processor.preprocess``.

    Returns:
        A tuple ``(pixel_values, labels, length)``:
        ``pixel_values`` is the image processor's ``'image'`` output;
        ``labels`` are the decoder input ids (each text wrapped in BOS/EOS)
        with pad ids replaced by ``-100``; ``length`` is the number of
        attended tokens per sequence, as int32. ``labels``/``length`` are
        torch tensors, numpy arrays or Python lists, following the container
        type produced by the tokenizer for ``return_tensors``.

    Raises:
        ValueError: If no tokenizer is available, the tokenizer lacks
            BOS/EOS tokens, preprocessing produced no images, or the number
            of texts / ids / categorys does not match the number of source
            images.
    """
    # Case 1: a single sample dict, e.g. {'image': <PIL...>}.
    if isinstance(images, dict) and "image" in images:
        images = images["image"]
    # Case 2: a list of sample dicts, e.g. [{'image': <PIL...>}, ...].
    elif (isinstance(images, (list, tuple)) and len(images) > 0
          and isinstance(images[0], dict) and "image" in images[0]):
        images = [img["image"] for img in images]

    # Number of input images; used to build the default (empty) texts.
    if isinstance(images, (list, tuple)):
        input_batch_size = len(images)
    else:
        input_batch_size = 1

    image_outputs: BatchFeature = self.image_processor.preprocess(
        images=images,
        return_tensors=return_tensors,
        **img_kwargs,
    )
    expanded_from = image_outputs.get("expanded_from_indices")

    # Normalise text / ids / categorys. ``text=None`` means inference mode,
    # so each input image gets an empty target string.
    if text is None:
        text_list = [""] * input_batch_size
    elif isinstance(text, str):
        text_list = [text]
    else:
        text_list = list(text)
    ids_list = [None] * len(text_list) if ids is None else list(ids)
    cats_list = ([None] * len(text_list)
                 if categorys is None else list(categorys))

    if expanded_from is None:
        # No expansion info: assume one processed image per text.
        num_in = len(text_list)
        expanded_from = list(range(num_in))
    else:
        # BatchFeature may have converted the indices to a tensor/array;
        # plain ints are needed to index the Python lists below.
        expanded_from = self._indices_to_list(expanded_from)
        if not expanded_from:
            raise ValueError(
                "[CMERProcessor] Image preprocessing produced no images "
                "(expanded_from_indices is empty).")
        num_in = max(expanded_from) + 1

    if not (len(text_list) == num_in == len(ids_list) == len(cats_list)):
        raise ValueError(
            f"[CMERProcessor] Mismatch between base counts: "
            f"text={len(text_list)}, ids={len(ids_list)}, "
            f"cats={len(cats_list)}, num_in(from expanded_from)={num_in}"
        )

    tokenizer = self.tokenizer
    if tokenizer is None:
        raise ValueError(
            "[CMERProcessor] No tokenizer is available; pass `tokenizer` or "
            "a valid `tokenizer_file` when constructing the processor.")
    bos_token = tokenizer.bos_token
    eos_token = tokenizer.eos_token
    if bos_token is None or eos_token is None:
        raise ValueError("Tokenizer must have a `bos_token` and an `eos_token`.")

    # Replicate each source text for every processed image derived from it.
    # The count check above guarantees every index is < len(text_list).
    expanded_texts = [
        f"{bos_token}{text_list[src]}{eos_token}" for src in expanded_from
    ]

    # NOTE(review): BOS/EOS are inserted as text while add_special_tokens=True
    # is also passed; if the tokenizer's post-processor adds them too they
    # would appear twice -- confirm against the tokenizer.json in use.
    text_outputs = tokenizer(
        expanded_texts,
        return_tensors=return_tensors,
        add_special_tokens=True,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
    )
    decoder_input_ids = text_outputs.pop("input_ids")

    labels = self._mask_padding(decoder_input_ids, tokenizer.pad_token_id)
    length = self._token_lengths(decoder_input_ids,
                                 text_outputs.get("attention_mask"))
    pixel_values = image_outputs["image"]
    return pixel_values, labels, length

@staticmethod
def _indices_to_list(indices) -> List[int]:
    """Return ``indices`` (list, numpy array or torch tensor) as a list of ints."""
    if hasattr(indices, "tolist"):
        indices = indices.tolist()
    return [int(i) for i in indices]

@staticmethod
def _mask_padding(input_ids, pad_id):
    """Copy ``input_ids`` into labels, replacing ``pad_id`` entries with -100."""
    if isinstance(input_ids, torch.Tensor):
        labels = input_ids.clone()
        if pad_id is not None:
            labels = labels.masked_fill(labels == pad_id, -100)
        return labels
    if isinstance(input_ids, np.ndarray):
        labels = input_ids.copy()
        if pad_id is not None:
            labels[labels == pad_id] = -100
        return labels
    # Python lists (return_tensors=None): build fresh, non-aliased rows.
    return [[(-100 if pad_id is not None and tok == pad_id else tok)
             for tok in seq] for seq in input_ids]

@staticmethod
def _token_lengths(input_ids, attention_mask=None):
    """Return the number of attended tokens per sequence as int32 values."""
    if isinstance(input_ids, torch.Tensor):
        if attention_mask is not None:
            return attention_mask.sum(dim=1).to(dtype=torch.int32)
        batch_size, seq_len = input_ids.shape[0], input_ids.shape[1]
        return torch.full((batch_size, ), seq_len, dtype=torch.int32)
    if isinstance(input_ids, np.ndarray):
        if attention_mask is not None:
            return np.asarray(attention_mask).sum(axis=1).astype(np.int32)
        return np.full((input_ids.shape[0], ),
                       input_ids.shape[1],
                       dtype=np.int32)
    # Python lists: without a mask, assume no padding.
    if attention_mask is not None:
        return [int(sum(mask)) for mask in attention_mask]
    return [len(seq) for seq in input_ids]
|
|
1023
|
+
|
|
1024
|
+
def batch_decode(self, *args, **kwargs):
    """Decode batches of token ids via the wrapped tokenizer.

    All positional and keyword arguments are forwarded unchanged to
    ``self.tokenizer.batch_decode``, and its result is returned as-is.
    """
    decode_fn = self.tokenizer.batch_decode
    return decode_fn(*args, **kwargs)
|