openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1025 @@
1
+ from __future__ import annotations
2
+ import math
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import numpy as np
6
+ from scipy.ndimage import zoom as scizoom
7
+
8
+ # Transformers imports
9
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
10
+ from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
11
+ from transformers.image_utils import (
12
+ ChannelDimension,
13
+ ImageInput,
14
+ PILImageResampling,
15
+ infer_channel_dimension_format,
16
+ make_flat_list_of_images,
17
+ to_numpy_array,
18
+ valid_images,
19
+ validate_preprocess_arguments,
20
+ )
21
+ from transformers import PreTrainedTokenizerFast
22
+ from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
23
+ from transformers import AutoImageProcessor, ProcessorMixin
24
+ import torch
25
+ # Third-party optional imports
26
+ logger = logging.get_logger(__name__)
27
+
28
+ try:
29
+ import albumentations as A
30
+ except Exception as _e:
31
+ A = None
32
+ _A_IMPORT_ERR = str(_e)
33
+
34
+ try:
35
+ import cv2
36
+ except Exception:
37
+ cv2 = None
38
+
39
+ if is_vision_available():
40
+ from PIL import Image, ImageOps, ImageDraw
41
+
42
# Albumentations Custom Transforms
# The custom transforms subclass ``A.ImageOnlyTransform``, so they can only be
# defined when the optional Albumentations import above succeeded.
if A is not None:

    class Erosion(A.ImageOnlyTransform):
        """Erode the image with a randomly sized elliptical kernel.

        Args:
            scale: A single int, or a 2-item ``(low, high)`` tuple/list. Each
                kernel side is drawn with ``np.random.randint(low, high)``,
                i.e. ``high`` is exclusive. When ``low >= high`` (always the
                case for a scalar ``scale``) the side is exactly ``low``.
            always_apply: Forwarded to ``A.ImageOnlyTransform``.
            p: Probability of applying the transform.

        Raises:
            ValueError: If a tuple/list ``scale`` does not have two items.
        """

        def __init__(self, scale, always_apply=False, p=0.5):
            super().__init__(always_apply=always_apply, p=p)
            if isinstance(scale, (tuple, list)):
                if len(scale) != 2:
                    raise ValueError(
                        f'scale must have exactly 2 items, got {len(scale)}')
                self.scale = scale
            else:
                self.scale = (scale, scale)

        def apply(self, img, **params):
            # No OpenCV means no morphology backend: pass the image through.
            if cv2 is None:
                return img
            low, high = int(self.scale[0]), int(self.scale[1])
            # np.random.randint excludes ``high`` and raises for low >= high
            # (e.g. a scalar scale), so keep the range at least one wide.
            ksize = np.random.randint(low, max(high, low + 1), 2)
            kernel = cv2.getStructuringElement(
                cv2.MORPH_ELLIPSE, (int(ksize[0]), int(ksize[1])))
            return cv2.erode(img, kernel, iterations=1)
63
+
64
+ class Dilation(A.ImageOnlyTransform):
65
+
66
+ def __init__(self, scale, always_apply=False, p=0.5):
67
+ super().__init__(always_apply=always_apply, p=p)
68
+ if type(scale) is tuple or type(scale) is list:
69
+ assert len(scale) == 2
70
+ self.scale = scale
71
+ else:
72
+ self.scale = (scale, scale)
73
+
74
+ def apply(self, img, **params):
75
+ if cv2 is None:
76
+ return img
77
+ kernel = cv2.getStructuringElement(
78
+ cv2.MORPH_ELLIPSE,
79
+ tuple(np.random.randint(self.scale[0], self.scale[1], 2)))
80
+ img = cv2.dilate(img, kernel, iterations=1)
81
+ return img
82
+
83
+ class Bitmap(A.ImageOnlyTransform):
84
+
85
+ def __init__(self, value=0, lower=200, p=0.5):
86
+ super().__init__(p=p)
87
+ self.lower = lower
88
+ self.value = value
89
+
90
+ def apply(self, img, **params):
91
+ img = img.copy()
92
+ img[img < self.lower] = self.value
93
+ return img
94
+
95
+ class Fog(A.ImageOnlyTransform):
96
+
97
+ def __init__(self, mag=-1, always_apply=False, p=1.):
98
+ super().__init__(always_apply=always_apply, p=p)
99
+ self.rng = np.random.default_rng()
100
+ self.mag = mag
101
+
102
+ def apply(self, img, **params):
103
+ img = Image.fromarray(img.astype(np.uint8))
104
+ w, h = img.size
105
+ c = [(1.5, 2), (2., 2), (2.5, 1.7)]
106
+ if self.mag < 0 or self.mag >= len(c):
107
+ index = self.rng.integers(0, len(c))
108
+ else:
109
+ index = self.mag
110
+ c = c[index]
111
+ n_channels = len(img.getbands())
112
+ isgray = n_channels == 1
113
+ img = np.asarray(img) / 255.
114
+ max_val = img.max()
115
+ max_size = 2**math.ceil(math.log2(max(w, h)) + 1)
116
+ fog = c[0] * plasma_fractal(mapsize=max_size,
117
+ wibbledecay=c[1],
118
+ rng=self.rng)[:h, :w][..., np.newaxis]
119
+ if isgray:
120
+ fog = np.squeeze(fog)
121
+ else:
122
+ fog = np.repeat(fog, 3, axis=2)
123
+ img += fog
124
+ img = np.clip(img * max_val / (max_val + c[0]), 0, 1) * 255
125
+ return img.astype(np.uint8)
126
+
127
+ class Frost(A.ImageOnlyTransform):
128
+
129
+ def __init__(self, mag=-1, always_apply=False, p=1.):
130
+ super().__init__(always_apply=always_apply, p=p)
131
+ self.rng = np.random.default_rng()
132
+ self.mag = mag
133
+
134
+ def apply(self, img, **params):
135
+ img = Image.fromarray(img.astype(np.uint8))
136
+ w, h = img.size
137
+ c = [(0.78, 0.22), (0.64, 0.36), (0.5, 0.5)]
138
+ if self.mag < 0 or self.mag >= len(c):
139
+ index = self.rng.integers(0, len(c))
140
+ else:
141
+ index = self.mag
142
+ c = c[index]
143
+ filename = [
144
+ './openrec/preprocess/cmer_frost/frost1.png',
145
+ './openrec/preprocess/cmer_frost/frost2.png',
146
+ './openrec/preprocess/cmer_frost/frost3.png',
147
+ './openrec/preprocess/cmer_frost/frost4.jpg',
148
+ './openrec/preprocess/cmer_frost/frost5.jpg',
149
+ './openrec/preprocess/cmer_frost/frost6.jpg',
150
+ ]
151
+ index = self.rng.integers(0, len(filename))
152
+ filename = filename[index]
153
+ try:
154
+ frost = Image.open(filename).convert('RGB')
155
+ except Exception:
156
+ # Fallback if file not found
157
+ return np.asarray(img).astype(np.uint8)
158
+
159
+ f_w, f_h = frost.size
160
+ if w / h > f_w / f_h:
161
+ f_h = round(f_h * w / f_w)
162
+ f_w = w
163
+ else:
164
+ f_w = round(f_w * h / f_h)
165
+ f_h = h
166
+ frost = np.asarray(frost.resize((f_w, f_h)))
167
+ y_start = self.rng.integers(0, f_h - h + 1)
168
+ x_start = self.rng.integers(0, f_w - w + 1)
169
+ frost = frost[y_start:y_start + h, x_start:x_start + w]
170
+ n_channels = len(img.getbands())
171
+ isgray = n_channels == 1
172
+ img = np.asarray(img)
173
+ if isgray:
174
+ img = np.expand_dims(img, axis=2)
175
+ img = np.repeat(img, 3, axis=2)
176
+ img = np.clip(np.round(c[0] * img + c[1] * frost), 0, 255)
177
+ img = img.astype(np.uint8)
178
+ if isgray:
179
+ img = np.squeeze(img)
180
+ return img
181
+
182
+ class Snow(A.ImageOnlyTransform):
183
+
184
+ def __init__(self, mag=-1, always_apply=False, p=1.):
185
+ super().__init__(always_apply=always_apply, p=p)
186
+ self.rng = np.random.default_rng()
187
+ self.mag = mag
188
+
189
+ def apply(self, img, **params):
190
+ img_pil = Image.fromarray(img.astype(np.uint8))
191
+ w, h = img_pil.size
192
+ c = [(0.1, 0.3, 3, 0.5, 10, 4, 0.8),
193
+ (0.2, 0.3, 2, 0.5, 12, 4, 0.7),
194
+ (0.55, 0.3, 4, 0.9, 12, 8, 0.7)]
195
+ if self.mag < 0 or self.mag >= len(c):
196
+ index = self.rng.integers(0, len(c))
197
+ else:
198
+ index = self.mag
199
+ c = c[index]
200
+ isgray = (len(img_pil.getbands()) == 1)
201
+ img = np.asarray(img_pil, dtype=np.float32) / 255.
202
+ if isgray:
203
+ img = np.repeat(img[..., None], 3, axis=2)
204
+ snow_layer = self.rng.normal(loc=c[0],
205
+ scale=c[1],
206
+ size=img.shape[:2])
207
+ snow_layer[snow_layer < c[3]] = 0
208
+ snow_layer = np.clip(snow_layer, 0, 1).astype(np.float32)
209
+ angle = float(self.rng.uniform(-135, -45))
210
+ snow_layer = motion_blur(snow_layer,
211
+ radius=c[4],
212
+ sigma=c[5],
213
+ angle=angle)
214
+ snow_layer = snow_layer[..., None]
215
+ img = c[6] * img
216
+ if cv2 is not None:
217
+ gray_img = (1 - c[6]) * np.maximum(
218
+ img,
219
+ cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).reshape(h, w, 1) *
220
+ 1.5 + 0.5)
221
+ img += gray_img
222
+ img = np.clip(img + snow_layer + np.rot90(snow_layer, k=2), 0,
223
+ 1) * 255
224
+ img = img.astype(np.uint8)
225
+ return np.squeeze(img) if isgray else img
226
+
227
+ class Rain(A.ImageOnlyTransform):
228
+
229
+ def __init__(self, mag=-1, always_apply=False, p=1.):
230
+ super().__init__(always_apply=always_apply, p=p)
231
+ self.rng = np.random.default_rng()
232
+ self.mag = mag
233
+
234
+ def apply(self, img, **params):
235
+ img = Image.fromarray(img.astype(np.uint8))
236
+ img = img.copy()
237
+ w, h = img.size
238
+ n_channels = len(img.getbands())
239
+ isgray = n_channels == 1
240
+ line_width = self.rng.integers(1, 2)
241
+ c = [50, 70, 90]
242
+ if self.mag < 0 or self.mag >= len(c):
243
+ index = 0
244
+ else:
245
+ index = self.mag
246
+ c = c[index]
247
+ n_rains = self.rng.integers(c, c + 20)
248
+ slant = self.rng.integers(-60, 60)
249
+ fillcolor = 200 if isgray else (200, 200, 200)
250
+ draw = ImageDraw.Draw(img)
251
+ max_length = min(w, h, 10)
252
+ for i in range(1, n_rains):
253
+ length = self.rng.integers(5, max_length)
254
+ x1 = self.rng.integers(0, w - length)
255
+ y1 = self.rng.integers(0, h - length)
256
+ x2 = x1 + length * math.sin(slant * math.pi / 180.)
257
+ y2 = y1 + length * math.cos(slant * math.pi / 180.)
258
+ x2 = int(x2)
259
+ y2 = int(y2)
260
+ draw.line([(x1, y1), (x2, y2)],
261
+ width=line_width,
262
+ fill=fillcolor)
263
+ img = np.asarray(img).astype(np.uint8)
264
+ return img
265
+
266
+ class Shadow(A.ImageOnlyTransform):
267
+
268
+ def __init__(self, mag=-1, always_apply=False, p=1.):
269
+ super().__init__(always_apply=always_apply, p=p)
270
+ self.rng = np.random.default_rng()
271
+ self.mag = mag
272
+
273
+ def apply(self, img, **params):
274
+ img = Image.fromarray(img.astype(np.uint8))
275
+ w, h = img.size
276
+ n_channels = len(img.getbands())
277
+ isgray = n_channels == 1
278
+ c = [64, 96, 128]
279
+ if self.mag < 0 or self.mag >= len(c):
280
+ index = 0
281
+ else:
282
+ index = self.mag
283
+ c = c[index]
284
+ img = img.convert('RGBA')
285
+ overlay = Image.new('RGBA', img.size, (255, 255, 255, 0))
286
+ draw = ImageDraw.Draw(overlay)
287
+ transparency = self.rng.integers(c, c + 32)
288
+ x1 = self.rng.integers(0, w // 2)
289
+ y1 = 0
290
+ x2 = self.rng.integers(w // 2, w)
291
+ y2 = 0
292
+ x3 = self.rng.integers(w // 2, w)
293
+ y3 = h - 1
294
+ x4 = self.rng.integers(0, w // 2)
295
+ y4 = h - 1
296
+ draw.polygon([(x1, y1), (x2, y2), (x3, y3), (x4, y4)],
297
+ fill=(0, 0, 0, transparency))
298
+ img = Image.alpha_composite(img, overlay)
299
+ img = img.convert('RGB')
300
+ if isgray:
301
+ img = ImageOps.grayscale(img)
302
+ img = np.asarray(img).astype(np.uint8)
303
+ return img
304
+
305
+ else:
306
+ # Fallback placeholders if Albumentations is missing
307
+ Erosion = None
308
+ Dilation = None
309
+ Bitmap = None
310
+ Fog = None
311
+ Frost = None
312
+ Snow = None
313
+ Rain = None
314
+ Shadow = None
315
+
316
+
317
def clipped_zoom(img, zoom_factor):
    """Zoom into the centre of ``img`` while keeping its spatial size.

    A centred crop of ``ceil(size / zoom_factor)`` pixels per spatial axis is
    upscaled with ``scipy.ndimage.zoom`` (linear interpolation). The result is
    then centre-trimmed back to the input's height and width. Only the two
    leading (spatial) axes are zoomed; any trailing axes (e.g. channels) keep
    their size.

    Args:
        img: Array of shape ``(H, W)`` or ``(H, W, ...)``; H and W may differ.
        zoom_factor: Magnification (> 0); values >= 1 are expected.

    Returns:
        Array with the same shape as ``img``.
    """
    h, w = img.shape[:2]
    ch = int(np.ceil(h / float(zoom_factor)))
    cw = int(np.ceil(w / float(zoom_factor)))
    top = (h - ch) // 2
    left = (w - cw) // 2
    factors = (zoom_factor, zoom_factor) + (1, ) * (img.ndim - 2)
    zoomed = scizoom(img[top:top + ch, left:left + cw], factors, order=1)
    # Trim any extra pixels introduced by rounding the zoomed size.
    trim_top = (zoomed.shape[0] - h) // 2
    trim_left = (zoomed.shape[1] - w) // 2
    return zoomed[trim_top:trim_top + h, trim_left:trim_left + w]
326
+
327
+
328
def disk(radius, alias_blur=0.1, dtype=np.float32):
    """Return a normalized, lightly Gaussian-blurred disk kernel.

    The sampling grid spans at least ``[-8, 8]`` on both axes; larger radii
    use ``[-radius, radius]`` and a bigger blur window. Without OpenCV a
    ``(1, 1)`` zero kernel is returned.

    Args:
        radius: Disk radius in pixels.
        alias_blur: ``sigmaX`` passed to ``cv2.GaussianBlur``.
        dtype: Kernel dtype.
    """
    if cv2 is None:
        return np.zeros((1, 1), dtype=dtype)
    half_extent, ksize = (8, (3, 3)) if radius <= 8 else (radius, (5, 5))
    axis = np.arange(-half_extent, half_extent + 1)
    xx, yy = np.meshgrid(axis, axis)
    kernel = np.asarray((xx**2 + yy**2) <= radius**2, dtype=dtype)
    kernel /= kernel.sum()
    return cv2.GaussianBlur(kernel, ksize=ksize, sigmaX=alias_blur)
341
+
342
+
343
def plasma_fractal(mapsize=256, wibbledecay=3, rng=None):
    """Generate a square diamond-square plasma fractal normalized to [0, 1].

    Args:
        mapsize: Side length of the output; must be a power of two.
        wibbledecay: Divisor applied to the random perturbation amplitude
            after every refinement level.
        rng: ``numpy.random.Generator`` for the perturbations; a fresh
            ``default_rng()`` is used when omitted.

    Returns:
        ``float64`` array of shape ``(mapsize, mapsize)``, shifted and scaled
        so its minimum is 0 and maximum is 1.
    """
    assert (mapsize & (mapsize - 1) == 0)
    grid = np.empty((mapsize, mapsize), dtype=np.float64)
    grid[0, 0] = 0
    step = mapsize
    amplitude = 100
    if rng is None:
        rng = np.random.default_rng()

    def perturbed_mean(block):
        # Mean of four summed neighbours plus noise whose spread follows the
        # current (decaying) amplitude.
        return block / 4 + amplitude * rng.uniform(-amplitude, amplitude,
                                                   block.shape)

    while step >= 2:
        half = step // 2
        # Square step: centre of each cell from its four corners.
        corners = grid[0:mapsize:step, 0:mapsize:step]
        square_sum = corners + np.roll(corners, shift=-1, axis=0)
        square_sum += np.roll(square_sum, shift=-1, axis=1)
        grid[half:mapsize:step, half:mapsize:step] = perturbed_mean(square_sum)
        # Diamond step: edge midpoints from neighbouring centres and corners.
        centres = grid[half:mapsize:step, half:mapsize:step]
        corners = grid[0:mapsize:step, 0:mapsize:step]
        row_sum = (centres + np.roll(centres, 1, axis=0)) + (
            corners + np.roll(corners, -1, axis=1))
        grid[0:mapsize:step, half:mapsize:step] = perturbed_mean(row_sum)
        col_sum = (centres + np.roll(centres, 1, axis=1)) + (
            corners + np.roll(corners, -1, axis=0))
        grid[half:mapsize:step, 0:mapsize:step] = perturbed_mean(col_sum)
        step //= 2
        amplitude /= wibbledecay
    grid -= grid.min()
    return grid / grid.max()
384
+
385
+
386
def motion_blur(img: np.ndarray, radius: int, sigma: float,
                angle: float) -> np.ndarray:
    """Blur ``img`` along a straight line rotated by ``angle`` degrees.

    A one-pixel horizontal line through the centre of a
    ``(2 * radius + 1)``-square kernel is rotated by ``angle`` and optionally
    Gaussian-smoothed with ``sigma``. It is normalized to sum to 1 and
    convolved with ``img`` using replicated borders. Without OpenCV ``img``
    is returned unchanged.

    Args:
        img: Image or 2-D layer to blur.
        radius: Half-length of the line kernel (``int``-converted).
        sigma: Gaussian sigma for smoothing the kernel; <= 0 disables it.
        angle: Rotation of the line in degrees (OpenCV convention).

    Returns:
        The filtered array (same depth as ``img``).
    """
    if cv2 is None:
        return img
    kernel_size = max(1, int(radius) * 2 + 1)
    # OpenCV pixel centres sit on integer coordinates, so the centre pixel of
    # an odd k x k kernel is ((k - 1) / 2, (k - 1) / 2), not (k / 2, k / 2).
    # Rotating about the true centre keeps the blur from shifting the image.
    centre = (kernel_size - 1) / 2.0
    psf = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    psf[kernel_size // 2] = 1.0
    rotation = cv2.getRotationMatrix2D((centre, centre), angle, 1)
    psf = cv2.warpAffine(psf, rotation, (kernel_size, kernel_size))
    if sigma > 0:
        psf = cv2.GaussianBlur(psf, (kernel_size, kernel_size), sigma)
    total = float(psf.sum())
    if total != 0:
        psf /= total
    return cv2.filter2D(img, -1, psf, borderType=cv2.BORDER_REPLICATE)
399
+
400
+
401
+ class CMERImageProcessor(BaseImageProcessor):
402
+ model_input_names = [
403
+ 'pixel_values', 'orig_spatial_shape', 'expanded_from_indices',
404
+ 'is_original_flags'
405
+ ]
406
+
407
+ def __init__(
408
+ self,
409
+ down_sample_ratio: int = 32,
410
+ do_convert_rgb: bool = True,
411
+ do_rescale: bool = True,
412
+ rescale_factor: float = 1.0 / 255.0,
413
+ do_normalize: bool = True,
414
+ image_mean: Optional[Union[float, list[float]]] = None,
415
+ image_std: Optional[Union[float, list[float]]] = None,
416
+ resample: 'PILImageResampling' = PILImageResampling.BILINEAR,
417
+ output_channel_format: ChannelDimension = ChannelDimension.FIRST,
418
+ pad_value_strategy: str = 'mean',
419
+ pad_value: Optional[Union[float, List[float]]] = None,
420
+ center_pad: bool = False,
421
+ do_augment: bool = True,
422
+ augment_prob: float = 1.0,
423
+ pre_pad_expand_ratio: float = 0.04,
424
+ pre_pad_min_px: int = 8,
425
+ aug_repeats: int = 0,
426
+ keep_original: bool = True,
427
+ num_workers: int = 8,
428
+ pad_num_workers: Optional[int] = None,
429
+ resize_backend: str = 'auto',
430
+ normalize_inplace: bool = True,
431
+ **kwargs,
432
+ ):
433
+ super().__init__(**kwargs)
434
+ self.down_sample_ratio = int(down_sample_ratio)
435
+ self.do_convert_rgb = bool(do_convert_rgb)
436
+ self.do_rescale = bool(do_rescale)
437
+ self.rescale_factor = float(rescale_factor)
438
+ self.do_normalize = bool(do_normalize)
439
+ self.image_mean = image_mean if image_mean is not None else [
440
+ 0.5, 0.5, 0.5
441
+ ]
442
+ self.image_std = image_std if image_std is not None else [
443
+ 0.5, 0.5, 0.5
444
+ ]
445
+ self.resample = resample
446
+ self.output_channel_format = output_channel_format
447
+ self.pad_value_strategy = str(pad_value_strategy).lower()
448
+ self.pad_value = pad_value
449
+ self.center_pad = bool(center_pad)
450
+ self.default_do_augment = bool(do_augment)
451
+ self.augment_prob = float(augment_prob)
452
+ self.pre_pad_expand_ratio = float(pre_pad_expand_ratio)
453
+ self.pre_pad_min_px = int(pre_pad_min_px)
454
+ self.aug_repeats = max(int(aug_repeats), 0)
455
+ self.keep_original = bool(keep_original)
456
+ self.num_workers = max(int(num_workers), 0)
457
+ self.pad_num_workers = pad_num_workers if pad_num_workers is not None else self.num_workers
458
+ self.resize_backend = resize_backend
459
+ self.normalize_inplace = bool(normalize_inplace)
460
+ self._augmentations = self._build_augmentations()
461
+
462
+ def _build_augmentations(self):
463
+ if A is None:
464
+ logger.warning_once(
465
+ f"[CMERImageProcessor] Albumentations 未安装,跳过图像增强。{_A_IMPORT_ERR if '_A_IMPORT_ERR' in globals() else ''}"
466
+ )
467
+ return None
468
+ tlist = []
469
+ if Bitmap is not None:
470
+ tlist.append(Bitmap(p=0.2))
471
+ weather_ops = []
472
+ for op in (Fog, Frost, Snow, Rain, Shadow):
473
+ if op is not None:
474
+ try:
475
+ weather_ops.append(op())
476
+ except Exception:
477
+ pass
478
+ if weather_ops:
479
+ tlist.append(A.OneOf(weather_ops, p=0.5))
480
+ morph_ops = []
481
+ if Erosion is not None:
482
+ try:
483
+ morph_ops.append(Erosion((2, 3)))
484
+ except Exception:
485
+ pass
486
+ if Dilation is not None:
487
+ try:
488
+ morph_ops.append(Dilation((2, 3)))
489
+ except Exception:
490
+ pass
491
+ if morph_ops:
492
+ tlist.append(A.OneOf(morph_ops, p=0.2))
493
+ tlist.extend([
494
+ A.ShiftScaleRotate(shift_limit=0,
495
+ scale_limit=(-.15, 0),
496
+ rotate_limit=1,
497
+ border_mode=0,
498
+ interpolation=3,
499
+ value=[255, 255, 255],
500
+ p=1),
501
+ A.GridDistortion(distort_limit=0.1,
502
+ border_mode=0,
503
+ interpolation=3,
504
+ value=[255, 255, 255],
505
+ p=0.5),
506
+ A.RGBShift(r_shift_limit=15,
507
+ g_shift_limit=15,
508
+ b_shift_limit=15,
509
+ p=0.3),
510
+ A.GaussNoise(var_limit=(10.0, 20.0), p=0.2),
511
+ A.RandomBrightnessContrast(0.05, (-0.2, 0), True, p=0.2),
512
+ ])
513
+ return A.Compose(tlist, p=self.augment_prob)
514
+
515
+ @staticmethod
516
+ def _constant_border(img: np.ndarray,
517
+ pad_px: int,
518
+ value: int = 255) -> np.ndarray:
519
+ if pad_px <= 0:
520
+ return img
521
+ if cv2 is not None:
522
+ return cv2.copyMakeBorder(img,
523
+ pad_px,
524
+ pad_px,
525
+ pad_px,
526
+ pad_px,
527
+ cv2.BORDER_CONSTANT,
528
+ value=[value, value, value])
529
+ return np.pad(img, ((pad_px, pad_px), (pad_px, pad_px), (0, 0)),
530
+ constant_values=value)
531
+
532
+ def _maybe_augment_uint8(self, img_uint8: np.ndarray, seed: Optional[int],
533
+ pre_pad_px: int) -> np.ndarray:
534
+ if self._augmentations is None:
535
+ return img_uint8
536
+ if pre_pad_px > 0:
537
+ img_uint8 = self._constant_border(img_uint8, pre_pad_px, value=255)
538
+ if seed is not None:
539
+ rng_state = np.random.get_state()
540
+ np.random.seed(seed)
541
+ try:
542
+ out = self._augmentations(image=img_uint8)['image']
543
+ finally:
544
+ np.random.set_state(rng_state)
545
+ return out
546
+ else:
547
+ return self._augmentations(image=img_uint8)['image']
548
+
549
+ def _prep_uint8(self, img, input_data_format) -> np.ndarray:
550
+ if self.do_convert_rgb:
551
+ img = convert_to_rgb(img)
552
+ np_img = to_numpy_array(img)
553
+ if input_data_format is None:
554
+ _fmt = infer_channel_dimension_format(np_img)
555
+ else:
556
+ _fmt = input_data_format
557
+ if _fmt == ChannelDimension.FIRST:
558
+ np_img = np.transpose(np_img, (1, 2, 0))
559
+ elif _fmt == ChannelDimension.LAST:
560
+ pass
561
+ else:
562
+ np_img = to_channel_dimension_format(np_img,
563
+ ChannelDimension.LAST,
564
+ input_channel_dim=_fmt)
565
+ if np_img.dtype != np.uint8:
566
+ if np_img.dtype.kind == 'f':
567
+ np_img = np.clip(np_img, 0.0, 1.0)
568
+ np_img = (np_img * 255.0 + 0.5).astype(np.uint8)
569
+ else:
570
+ np_img = np_img.astype(np.uint8)
571
+ return np_img
572
+
573
+ def _resize_uint8(self, img_uint8: np.ndarray, th: int, tw: int,
574
+ backend: str) -> np.ndarray:
575
+ if backend == 'cv2' and cv2 is not None:
576
+ return cv2.resize(img_uint8, (tw, th), interpolation=1)
577
+ return resize(img_uint8,
578
+ size=(th, tw),
579
+ resample=self.resample,
580
+ input_data_format=ChannelDimension.LAST)
581
+
582
+ def preprocess_auto(self,
583
+ images: ImageInput,
584
+ return_tensors: Optional[Union[str,
585
+ TensorType]] = None,
586
+ trainer=None,
587
+ **kwargs) -> BatchFeature:
588
+ if trainer is not None and getattr(trainer, 'is_in_train', False):
589
+ kwargs.setdefault('do_augment', True)
590
+ return self.preprocess(images, return_tensors=return_tensors, **kwargs)
591
+
592
+ @filter_out_non_signature_kwargs()
593
+ def preprocess(
594
+ self,
595
+ images: ImageInput,
596
+ return_tensors: Optional[Union[str, TensorType]] = None,
597
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
598
+ do_convert_rgb: Optional[bool] = None,
599
+ do_rescale: Optional[bool] = None,
600
+ rescale_factor: Optional[float] = None,
601
+ do_normalize: Optional[bool] = None,
602
+ image_mean: Optional[Union[float, list[float]]] = None,
603
+ image_std: Optional[Union[float, list[float]]] = None,
604
+ down_sample_ratio: Optional[int] = None,
605
+ resample: Optional['PILImageResampling'] = None,
606
+ output_channel_format: Optional[ChannelDimension] = None,
607
+ pad_value_strategy: Optional[str] = None,
608
+ pad_value: Optional[Union[float, List[float]]] = None,
609
+ center_pad: Optional[bool] = None,
610
+ do_augment: Optional[bool] = True,
611
+ augment_seed: Optional[int] = None,
612
+ pre_pad_expand_ratio: Optional[float] = None,
613
+ pre_pad_min_px: Optional[int] = None,
614
+ aug_repeats: Optional[int] = None,
615
+ keep_original: Optional[bool] = None,
616
+ num_workers: Optional[int] = None,
617
+ pad_num_workers: Optional[int] = None,
618
+ resize_backend: Optional[str] = None,
619
+ normalize_inplace: Optional[bool] = None,
620
+ ) -> BatchFeature:
621
+ do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb
622
+ do_rescale = self.do_rescale if do_rescale is None else do_rescale
623
+ rescale_factor = self.rescale_factor if rescale_factor is None else rescale_factor
624
+ do_normalize = self.do_normalize if do_normalize is None else do_normalize
625
+ image_mean = self.image_mean if image_mean is None else image_mean
626
+ image_std = self.image_std if image_std is None else image_std
627
+ down_sample_ratio = self.down_sample_ratio if down_sample_ratio is None else int(
628
+ down_sample_ratio)
629
+ resample = self.resample if resample is None else resample
630
+ output_channel_format = self.output_channel_format if output_channel_format is None else output_channel_format
631
+ pad_value_strategy = self.pad_value_strategy if pad_value_strategy is None else pad_value_strategy.lower(
632
+ )
633
+ pad_value = self.pad_value if pad_value is None else pad_value
634
+ center_pad = self.center_pad if center_pad is None else bool(
635
+ center_pad)
636
+ do_augment = self.default_do_augment if do_augment is None else bool(
637
+ do_augment)
638
+ pre_pad_expand_ratio = self.pre_pad_expand_ratio if pre_pad_expand_ratio is None else float(
639
+ pre_pad_expand_ratio)
640
+ pre_pad_min_px = self.pre_pad_min_px if pre_pad_min_px is None else int(
641
+ pre_pad_min_px)
642
+ aug_repeats = self.aug_repeats if aug_repeats is None else max(
643
+ int(aug_repeats), 0)
644
+ keep_original = self.keep_original if keep_original is None else bool(
645
+ keep_original)
646
+ num_workers = self.num_workers if num_workers is None else max(
647
+ int(num_workers), 0)
648
+ pad_num_workers = self.pad_num_workers if pad_num_workers is None else max(
649
+ int(pad_num_workers), 0)
650
+ resize_backend = (self.resize_backend if resize_backend is None else
651
+ resize_backend).lower()
652
+ normalize_inplace = self.normalize_inplace if normalize_inplace is None else bool(
653
+ normalize_inplace)
654
+ if type(images) is dict:
655
+ images = images.get('image', None)
656
+ images = self.fetch_images(images)
657
+ else:
658
+ images = self.fetch_images(images)
659
+ images = make_flat_list_of_images(images)
660
+ if not valid_images(images):
661
+ raise ValueError(
662
+ 'Invalid image type. Must be PIL.Image.Image, numpy.ndarray, or torch.Tensor'
663
+ )
664
+ validate_preprocess_arguments(
665
+ do_rescale=do_rescale,
666
+ rescale_factor=rescale_factor,
667
+ do_normalize=do_normalize,
668
+ image_mean=image_mean,
669
+ image_std=image_std,
670
+ )
671
+ from concurrent.futures import ThreadPoolExecutor, as_completed
672
+
673
+ def _process_one(idx_img: int):
674
+ base = self._prep_uint8(images[idx_img], input_data_format)
675
+ h0, w0 = base.shape[:2]
676
+ results_imgs: List[np.ndarray] = []
677
+ results_sizes: List[Tuple[int, int]] = []
678
+ results_from: List[int] = []
679
+ results_flag: List[bool] = []
680
+ cand: List[Tuple[np.ndarray, bool]] = []
681
+ if keep_original:
682
+ cand.append((base, True))
683
+ if do_augment and self._augmentations is not None and aug_repeats > 0:
684
+ est_pad = max(
685
+ int(max(h0, w0) * pre_pad_expand_ratio),
686
+ pre_pad_min_px if pre_pad_expand_ratio > 0 else 0)
687
+ for k in range(aug_repeats):
688
+ seed_k = None if augment_seed is None else (
689
+ int(augment_seed) + idx_img * (aug_repeats + 1) + k)
690
+ aug_img = self._maybe_augment_uint8(base,
691
+ seed=seed_k,
692
+ pre_pad_px=est_pad)
693
+ cand.append((aug_img, False))
694
+
695
+ is_cv2_avail = (resize_backend == 'auto' and cv2 is not None)
696
+ be = 'cv2' if is_cv2_avail else resize_backend
697
+ max_long_edge = 1024
698
+
699
+ for uint8_img, is_orig in cand:
700
+ hh, ww = uint8_img.shape[:2]
701
+ if max(hh, ww) > max_long_edge:
702
+ scale = float(max_long_edge) / float(max(hh, ww))
703
+ targ_h = max(1, int(math.floor(hh * scale)))
704
+ targ_w = max(1, int(math.floor(ww * scale)))
705
+ uint8_img = self._resize_uint8(uint8_img, targ_h, targ_w,
706
+ be)
707
+ hh, ww = uint8_img.shape[:2]
708
+ MIN_HW = 224
709
+ ds = down_sample_ratio
710
+ ceil_h = max(MIN_HW, math.ceil(hh / ds) * ds)
711
+ ceil_w = max(MIN_HW, math.ceil(ww / ds) * ds)
712
+ if max(ceil_h, ceil_w) <= max_long_edge:
713
+ th, tw = ceil_h, ceil_w
714
+ else:
715
+ floor_h = max(MIN_HW, (hh // ds) * ds)
716
+ floor_w = max(MIN_HW, (ww // ds) * ds)
717
+ if floor_h <= 0 or floor_w <= 0:
718
+ floor_h = max(MIN_HW,
719
+ min(hh, max_long_edge) // ds * ds)
720
+ floor_w = max(MIN_HW,
721
+ min(ww, max_long_edge) // ds * ds)
722
+ th, tw = floor_h, floor_w
723
+
724
+ rs_img = self._resize_uint8(uint8_img, th, tw, be)
725
+ if do_rescale:
726
+ rs_img = rs_img.astype(np.float32)
727
+ np.multiply(rs_img,
728
+ float(rescale_factor),
729
+ out=rs_img,
730
+ casting='unsafe')
731
+ else:
732
+ rs_img = rs_img.astype(np.float32)
733
+ results_imgs.append(rs_img)
734
+ results_sizes.append((th, tw))
735
+ results_from.append(idx_img)
736
+ results_flag.append(is_orig)
737
+ return results_imgs, results_sizes, results_from, results_flag
738
+
739
+ proc_list: List[np.ndarray] = []
740
+ rec_sizes: List[Tuple[int, int]] = []
741
+ from_indices: List[int] = []
742
+ is_orig_flags: List[bool] = []
743
+ if num_workers and num_workers > 1 and len(images) > 1:
744
+ with ThreadPoolExecutor(max_workers=num_workers) as ex:
745
+ futs = [ex.submit(_process_one, i) for i in range(len(images))]
746
+ for fu in as_completed(futs):
747
+ imgs_i, sizes_i, from_i, flag_i = fu.result()
748
+ proc_list.extend(imgs_i)
749
+ rec_sizes.extend(sizes_i)
750
+ from_indices.extend(from_i)
751
+ is_orig_flags.extend(flag_i)
752
+ else:
753
+ for i in range(len(images)):
754
+ imgs_i, sizes_i, from_i, flag_i = _process_one(i)
755
+ proc_list.extend(imgs_i)
756
+ rec_sizes.extend(sizes_i)
757
+ from_indices.extend(from_i)
758
+ is_orig_flags.extend(flag_i)
759
+ if len(proc_list) == 0:
760
+ return BatchFeature(data={
761
+ 'image': [],
762
+ 'orig_spatial_shape': [],
763
+ 'expanded_from_indices': [],
764
+ 'is_original_flags': []
765
+ },
766
+ tensor_type=return_tensors)
767
+ max_h = max(h for h, _ in rec_sizes)
768
+ max_w = max(w for _, w in rec_sizes)
769
+ mean = np.array(image_mean, dtype=np.float32)
770
+ std = np.array(image_std, dtype=np.float32)
771
+ inv_std = 1.0 / np.where(std == 0, 1.0, std)
772
+
773
+ def _maybe_scale_stats_to_image_domain(
774
+ _arr: np.ndarray, exemplar: np.ndarray) -> np.ndarray:
775
+ if not do_rescale and exemplar.max() > 1.5 and _arr.max() <= 1.5:
776
+ return _arr * 255.0
777
+ return _arr
778
+
779
+ def _make_pad_color(c: int, exemplar: np.ndarray) -> np.ndarray:
780
+ _mean = _maybe_scale_stats_to_image_domain(mean, exemplar)
781
+ if pad_value_strategy == 'mean':
782
+ col = _mean
783
+ elif pad_value_strategy == 'white':
784
+ col = np.ones(
785
+ (c, ), dtype=np.float32) * (1.0 if do_rescale else 255.0)
786
+ elif pad_value_strategy == 'zero':
787
+ col = np.zeros((c, ), dtype=np.float32)
788
+ elif pad_value_strategy == 'custom':
789
+ if pad_value is None:
790
+ col = _mean
791
+ else:
792
+ col = np.array(pad_value, dtype=np.float32)
793
+ if col.ndim == 0:
794
+ col = np.full((c, ), float(col), dtype=np.float32)
795
+ if col.shape[0] != c:
796
+ raise ValueError(
797
+ f'pad_value length must match channels={c}')
798
+ else:
799
+ col = _mean
800
+ return col
801
+
802
+ def _to_ch_first(arr: np.ndarray) -> np.ndarray:
803
+ return np.transpose(arr, (2, 0, 1))
804
+
805
+ batched: List[np.ndarray] = [None] * len(proc_list)
806
+
807
+ def _pad_one(i: int):
808
+ np_img = proc_list[i]
809
+ h, w = rec_sizes[i]
810
+ C = np_img.shape[2]
811
+ pad_color = _make_pad_color(C, np_img)
812
+ if center_pad:
813
+ y0 = (max_h - h) // 2
814
+ x0 = (max_w - w) // 2
815
+ else:
816
+ y0 = 0
817
+ x0 = 0
818
+ pad_img = np.empty((max_h, max_w, C), dtype=np.float32)
819
+ pad_img[...] = pad_color
820
+ pad_img[y0:y0 + h, x0:x0 + w, :] = np_img
821
+ if do_normalize:
822
+ _mean = _maybe_scale_stats_to_image_domain(mean, pad_img)
823
+ _invstd = _maybe_scale_stats_to_image_domain(inv_std, pad_img)
824
+ if normalize_inplace:
825
+ np.subtract(pad_img, _mean, out=pad_img)
826
+ np.multiply(pad_img, _invstd, out=pad_img)
827
+ else:
828
+ pad_img = (pad_img - _mean) * _invstd
829
+ batched[i] = _to_ch_first(
830
+ pad_img
831
+ ) if output_channel_format == ChannelDimension.FIRST else pad_img
832
+
833
+ if pad_num_workers and pad_num_workers > 1 and len(proc_list) > 1:
834
+ with ThreadPoolExecutor(max_workers=pad_num_workers) as ex:
835
+ list(ex.map(_pad_one, range(len(proc_list))))
836
+ else:
837
+ for i in range(len(proc_list)):
838
+ _pad_one(i)
839
+ return BatchFeature(
840
+ data={
841
+ 'image': batched,
842
+ 'orig_spatial_shape': rec_sizes,
843
+ 'expanded_from_indices': from_indices,
844
+ 'is_original_flags': is_orig_flags,
845
+ },
846
+ tensor_type=return_tensors,
847
+ )
848
+
849
+
850
+ AutoImageProcessor.register('CMER',
851
+ slow_image_processor_class=CMERImageProcessor)
852
+ class CMERProcessor(ProcessorMixin):
853
+ attributes = ['image_processor', 'tokenizer']
854
+ image_processor_class = 'CMERImageProcessor'
855
+ tokenizer_class = 'PreTrainedTokenizerFast'
856
+
857
+ def __init__(
858
+ self,
859
+ image_processor=None,
860
+ tokenizer=None,
861
+ tokenizer_file: str = './configs/rec/cmer/cmer_tokenizer/tokenizer.json',
862
+ **kwargs
863
+ ):
864
+ if image_processor is None:
865
+ # 确保这里能正确导入你的 CMERImageProcessor
866
+ image_processor = CMERImageProcessor(**kwargs)
867
+
868
+ if tokenizer is None:
869
+ try:
870
+ tokenizer = PreTrainedTokenizerFast(
871
+ tokenizer_file=tokenizer_file,
872
+ padding_side="right",
873
+ truncation_side="right",
874
+ pad_token="<|pad|>",
875
+ bos_token="<|bos|>",
876
+ eos_token="<|eos|>",
877
+ unk_token="<|unk|>",
878
+ )
879
+ except Exception as e:
880
+ # logger 需要外部定义或引入,这里简单用 print 代替
881
+ print(f"Failed to initialize default tokenizer from {tokenizer_file}. Error: {e}")
882
+ tokenizer = None
883
+
884
+ super().__init__(image_processor=image_processor, tokenizer=tokenizer)
885
+
886
+
887
+ def __call__(
888
+ self,
889
+ images: ImageInput,
890
+ text: Union[str, List[str]]=None,
891
+ ids=None,
892
+ categorys=None,
893
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
894
+ padding: Union[bool, str] = True,
895
+ truncation: bool = True,
896
+ max_length: Optional[int] = None,
897
+ **img_kwargs,
898
+ ):
899
+ if isinstance(images, dict) and "image" in images:
900
+ images = images["image"]
901
+ # 情况 2: 列表样本,例如 [{'image': <PIL...>}, {'image': <PIL...>}]
902
+ elif isinstance(images, (list, tuple)) and len(images) > 0 and isinstance(images[0], dict) and "image" in images[0]:
903
+ images = [img["image"] for img in images]
904
+ # 计算输入图片的数量,用于后续生成默认 text
905
+ if isinstance(images, (list, tuple)):
906
+ input_batch_size = len(images)
907
+ else:
908
+ input_batch_size = 1
909
+ image_outputs: BatchFeature = self.image_processor.preprocess(
910
+ images=images,
911
+ return_tensors=return_tensors,
912
+ **img_kwargs,
913
+ )
914
+ expanded_from = image_outputs.get("expanded_from_indices")
915
+ # =================================================================
916
+ # 2. [修复核心报错] 处理 text/ids/categorys 为 None 的情况
917
+ # =================================================================
918
+ # 如果 text 为 None (推理模式),生成空字符串列表
919
+ if text is None:
920
+ text_list = [""] * input_batch_size
921
+ elif isinstance(text, str):
922
+ text_list = [text]
923
+ else:
924
+ text_list = list(text)
925
+
926
+ # 如果 ids 为 None,生成默认占位符
927
+ if ids is None:
928
+ ids_list = [None] * len(text_list)
929
+ else:
930
+ ids_list = list(ids)
931
+
932
+ # 如果 categorys 为 None,生成默认占位符
933
+ if categorys is None:
934
+ cats_list = [None] * len(text_list)
935
+ else:
936
+ cats_list = list(categorys)
937
+ # =================================================================
938
+
939
+ if expanded_from is None:
940
+ num_in = len(text_list)
941
+ expanded_from = list(range(num_in))
942
+ else:
943
+ num_in = max(expanded_from) + 1
944
+
945
+ # 检查长度一致性
946
+ if not (len(text_list) == num_in == len(ids_list) == len(cats_list)):
947
+ raise ValueError(
948
+ f"[CMERProcessor] Mismatch between base counts: "
949
+ f"text={len(text_list)}, ids={len(ids_list)}, "
950
+ f"cats={len(cats_list)}, num_in(from expanded_from)={num_in}"
951
+ )
952
+
953
+ bos_token = self.tokenizer.bos_token
954
+ eos_token = self.tokenizer.eos_token
955
+ if bos_token is None or eos_token is None:
956
+ raise ValueError("Tokenizer must have a `bos_token` and an `eos_token`.")
957
+
958
+ base_texts = text_list
959
+ base_ids = ids_list
960
+ base_cats = cats_list
961
+
962
+ try:
963
+ expanded_texts = [
964
+ f"{bos_token}{base_texts[src]}{eos_token}" for src in expanded_from
965
+ ]
966
+ expanded_ids = [base_ids[src] for src in expanded_from]
967
+ expanded_cats = [base_cats[src] for src in expanded_from]
968
+ except IndexError:
969
+ raise ValueError(
970
+ f"[CMERProcessor] expanded_from_indices contains index out of range: "
971
+ f"max={max(expanded_from)}, but num_in={num_in}"
972
+ )
973
+
974
+ text_outputs = self.tokenizer(
975
+ expanded_texts,
976
+ return_tensors=return_tensors,
977
+ add_special_tokens=True,
978
+ padding=padding,
979
+ truncation=truncation,
980
+ max_length=max_length,
981
+ )
982
+
983
+ text_outputs["decoder_input_ids"] = text_outputs.pop("input_ids")
984
+ data = {**image_outputs, **text_outputs}
985
+
986
+ labels = (
987
+ data["decoder_input_ids"].clone()
988
+ if return_tensors is not None
989
+ else list(data["decoder_input_ids"])
990
+ )
991
+
992
+ pad_id = self.tokenizer.pad_token_id
993
+ if pad_id is None:
994
+ data["labels"] = labels
995
+ else:
996
+ if hasattr(labels, "masked_fill"):
997
+ labels = labels.masked_fill(labels == pad_id, -100)
998
+ else:
999
+ labels = [[(-100 if tok == pad_id else tok) for tok in seq] for seq in labels]
1000
+ data["labels"] = labels
1001
+
1002
+ # bf = BatchFeature(data=data, tensor_type=return_tensors)
1003
+ # bf["ids"] = expanded_ids
1004
+ # bf["categorys"] = expanded_cats
1005
+ input_ids = data["decoder_input_ids"]
1006
+ # return bf
1007
+
1008
+ if "attention_mask" in text_outputs:
1009
+ # attention_mask shape: [batch, seq_len]
1010
+ # sum(dim=1) 得到每个样本的有效长度
1011
+ length = text_outputs["attention_mask"].sum(dim=1)
1012
+ # 确保是 int32 或 int64
1013
+ length = length.to(dtype=torch.int32)
1014
+ else:
1015
+ # 如果没有 attention_mask,假设没有 padding,直接取 shape
1016
+ seq_len = input_ids.shape[1]
1017
+ batch_size = input_ids.shape[0]
1018
+ length = torch.full((batch_size,), seq_len, dtype=torch.int32)
1019
+
1020
+ # 6. 返回 Tuple (pixel_values, labels, length)
1021
+ pixel_values = data['image']
1022
+ return pixel_values, labels, length
1023
+
1024
+ def batch_decode(self, *args, **kwargs):
1025
+ return self.tokenizer.batch_decode(*args, **kwargs)