mineru 2.6.8__py3-none-any.whl → 2.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mineru/backend/hybrid/__init__.py +1 -0
- mineru/backend/hybrid/hybrid_analyze.py +526 -0
- mineru/backend/hybrid/hybrid_magic_model.py +617 -0
- mineru/backend/hybrid/hybrid_model_output_to_middle_json.py +212 -0
- mineru/backend/pipeline/batch_analyze.py +9 -1
- mineru/backend/pipeline/model_init.py +96 -1
- mineru/backend/pipeline/pipeline_analyze.py +6 -4
- mineru/backend/pipeline/pipeline_middle_json_mkcontent.py +32 -41
- mineru/backend/vlm/utils.py +3 -1
- mineru/backend/vlm/vlm_analyze.py +12 -12
- mineru/backend/vlm/vlm_magic_model.py +24 -89
- mineru/backend/vlm/vlm_middle_json_mkcontent.py +112 -12
- mineru/cli/client.py +17 -17
- mineru/cli/common.py +170 -20
- mineru/cli/fast_api.py +39 -13
- mineru/cli/gradio_app.py +232 -206
- mineru/model/mfd/yolo_v8.py +12 -6
- mineru/model/mfr/unimernet/Unimernet.py +71 -3
- mineru/resources/header.html +5 -1
- mineru/utils/boxbase.py +23 -0
- mineru/utils/char_utils.py +55 -0
- mineru/utils/engine_utils.py +74 -0
- mineru/utils/enum_class.py +18 -1
- mineru/utils/magic_model_utils.py +85 -2
- mineru/utils/span_pre_proc.py +5 -3
- mineru/utils/table_merge.py +5 -21
- mineru/version.py +1 -1
- mineru-2.7.0.dist-info/METADATA +433 -0
- {mineru-2.6.8.dist-info → mineru-2.7.0.dist-info}/RECORD +33 -27
- mineru-2.6.8.dist-info/METADATA +0 -954
- {mineru-2.6.8.dist-info → mineru-2.7.0.dist-info}/WHEEL +0 -0
- {mineru-2.6.8.dist-info → mineru-2.7.0.dist-info}/entry_points.txt +0 -0
- {mineru-2.6.8.dist-info → mineru-2.7.0.dist-info}/licenses/LICENSE.md +0 -0
- {mineru-2.6.8.dist-info → mineru-2.7.0.dist-info}/top_level.txt +0 -0
mineru/backend/hybrid/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) Opendatalab. All rights reserved.
mineru/backend/hybrid/hybrid_analyze.py
@@ -0,0 +1,526 @@
+# Copyright (c) Opendatalab. All rights reserved.
+import os
+import time
+from collections import defaultdict
+
+import cv2
+import numpy as np
+from loguru import logger
+from mineru_vl_utils import MinerUClient
+from mineru_vl_utils.structs import BlockType
+from tqdm import tqdm
+
+from mineru.backend.hybrid.hybrid_model_output_to_middle_json import result_to_middle_json
+from mineru.backend.pipeline.model_init import HybridModelSingleton
+from mineru.backend.vlm.vlm_analyze import ModelSingleton
+from mineru.data.data_reader_writer import DataWriter
+from mineru.utils.config_reader import get_device
+from mineru.utils.enum_class import ImageType, NotExtractType
+from mineru.utils.model_utils import crop_img, get_vram, clean_memory
+from mineru.utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, sorted_boxes, merge_det_boxes, \
+    update_det_boxes, OcrConfidence
+from mineru.utils.pdf_classify import classify
+from mineru.utils.pdf_image_tools import load_images_from_pdf
+
+os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU
+os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations update checks
+
+MFR_BASE_BATCH_SIZE = 16
+OCR_DET_BASE_BATCH_SIZE = 16
+
+not_extract_list = [item.value for item in NotExtractType]
+
+def ocr_classify(pdf_bytes, parse_method: str = 'auto',) -> bool:
+    # Determine the OCR setting
+    _ocr_enable = False
+    if parse_method == 'auto':
+        if classify(pdf_bytes) == 'ocr':
+            _ocr_enable = True
+    elif parse_method == 'ocr':
+        _ocr_enable = True
+    return _ocr_enable
+
+def ocr_det(
+    hybrid_pipeline_model,
+    np_images,
+    results,
+    mfd_res,
+    _ocr_enable,
+    batch_radio: int = 1,
+):
+    ocr_res_list = []
+    if not hybrid_pipeline_model.enable_ocr_det_batch:
+        # Non-batch mode - process page by page
+        for np_image, page_mfd_res, page_results in tqdm(
+            zip(np_images, mfd_res, results),
+            total=len(np_images),
+            desc="OCR-det"
+        ):
+            ocr_res_list.append([])
+            img_height, img_width = np_image.shape[:2]
+            for res in page_results:
+                if res['type'] not in not_extract_list:
+                    continue
+                x0 = max(0, int(res['bbox'][0] * img_width))
+                y0 = max(0, int(res['bbox'][1] * img_height))
+                x1 = min(img_width, int(res['bbox'][2] * img_width))
+                y1 = min(img_height, int(res['bbox'][3] * img_height))
+                if x1 <= x0 or y1 <= y0:
+                    continue
+                res['poly'] = [x0, y0, x1, y0, x1, y1, x0, y1]
+                new_image, useful_list = crop_img(
+                    res, np_image, crop_paste_x=50, crop_paste_y=50
+                )
+                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                    page_mfd_res, useful_list
+                )
+                bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
+                ocr_res = hybrid_pipeline_model.ocr_model.ocr(
+                    bgr_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                )[0]
+                if ocr_res:
+                    ocr_result_list = get_ocr_result_list(
+                        ocr_res, useful_list, _ocr_enable, bgr_image, hybrid_pipeline_model.lang
+                    )
+
+                    ocr_res_list[-1].extend(ocr_result_list)
+    else:
+        # Batch mode - group by language and resolution
+        # Collect all cropped images that need OCR detection
+        all_cropped_images_info = []
+
+        for np_image, page_mfd_res, page_results in zip(
+            np_images, mfd_res, results
+        ):
+            ocr_res_list.append([])
+            img_height, img_width = np_image.shape[:2]
+            for res in page_results:
+                if res['type'] not in not_extract_list:
+                    continue
+                x0 = max(0, int(res['bbox'][0] * img_width))
+                y0 = max(0, int(res['bbox'][1] * img_height))
+                x1 = min(img_width, int(res['bbox'][2] * img_width))
+                y1 = min(img_height, int(res['bbox'][3] * img_height))
+                if x1 <= x0 or y1 <= y0:
+                    continue
+                res['poly'] = [x0, y0, x1, y0, x1, y1, x0, y1]
+                new_image, useful_list = crop_img(
+                    res, np_image, crop_paste_x=50, crop_paste_y=50
+                )
+                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                    page_mfd_res, useful_list
+                )
+                bgr_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
+                all_cropped_images_info.append((
+                    bgr_image, useful_list, adjusted_mfdetrec_res, ocr_res_list[-1]
+                ))
+
+        # Group by resolution and do the padding at the same time
+        RESOLUTION_GROUP_STRIDE = 64  # 32
+
+        resolution_groups = defaultdict(list)
+        for crop_info in all_cropped_images_info:
+            cropped_img = crop_info[0]
+            h, w = cropped_img.shape[:2]
+            # Compute the target size directly and use it as the group key
+            target_h = ((h + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+            target_w = ((w + RESOLUTION_GROUP_STRIDE - 1) // RESOLUTION_GROUP_STRIDE) * RESOLUTION_GROUP_STRIDE
+            group_key = (target_h, target_w)
+            resolution_groups[group_key].append(crop_info)
+
+        # Run batch processing for each resolution group
+        for (target_h, target_w), group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det"):
+            # Pad all images to the unified size
+            batch_images = []
+            for crop_info in group_crops:
+                img = crop_info[0]
+                h, w = img.shape[:2]
+                # Create a white background of the target size
+                padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+                padded_img[:h, :w] = img
+                batch_images.append(padded_img)
+
+            # Batch detection
+            det_batch_size = min(len(batch_images), batch_radio*OCR_DET_BASE_BATCH_SIZE)
+            batch_results = hybrid_pipeline_model.ocr_model.text_detector.batch_predict(batch_images, det_batch_size)
+
+            # Process the batch results
+            for crop_info, (dt_boxes, _) in zip(group_crops, batch_results):
+                bgr_image, useful_list, adjusted_mfdetrec_res, ocr_page_res_list = crop_info
+
+                if dt_boxes is not None and len(dt_boxes) > 0:
+                    # Process the detected boxes
+                    dt_boxes_sorted = sorted_boxes(dt_boxes)
+                    dt_boxes_merged = merge_det_boxes(dt_boxes_sorted) if dt_boxes_sorted else []
+
+                    # Update the detection boxes according to the formula positions
+                    dt_boxes_final = (update_det_boxes(dt_boxes_merged, adjusted_mfdetrec_res)
+                                      if dt_boxes_merged and adjusted_mfdetrec_res
+                                      else dt_boxes_merged)
+
+                    if dt_boxes_final:
+                        ocr_res = [box.tolist() if hasattr(box, 'tolist') else box for box in dt_boxes_final]
+                        ocr_result_list = get_ocr_result_list(
+                            ocr_res, useful_list, _ocr_enable, bgr_image, hybrid_pipeline_model.lang
+                        )
+                        ocr_page_res_list.extend(ocr_result_list)
+    return ocr_res_list
+
+def mask_image_regions(np_images, results):
+    # Based on the VLM results, mask the image, table and equation blocks on each page with a white background
+    for np_image, vlm_page_results in zip(np_images, results):
+        img_height, img_width = np_image.shape[:2]
+        # Collect the regions that need to be masked
+        mask_regions = []
+        for block in vlm_page_results:
+            if block['type'] in [BlockType.IMAGE, BlockType.TABLE, BlockType.EQUATION]:
+                bbox = block['bbox']
+                # Convert normalized coordinates to pixel coordinates and clamp them to the image bounds
+                x0 = max(0, int(bbox[0] * img_width))
+                y0 = max(0, int(bbox[1] * img_height))
+                x1 = min(img_width, int(bbox[2] * img_width))
+                y1 = min(img_height, int(bbox[3] * img_height))
+                # Only keep valid regions
+                if x1 > x0 and y1 > y0:
+                    mask_regions.append((y0, y1, x0, x1))
+        # Apply all masks
+        for y0, y1, x0, x1 in mask_regions:
+            np_image[y0:y1, x0:x1, :] = 255
+    return np_images
+
+def normalize_poly_to_bbox(item, page_width, page_height):
+    """Normalize poly coordinates into a bbox"""
+    poly = item['poly']
+    x0 = min(max(poly[0] / page_width, 0), 1)
+    y0 = min(max(poly[1] / page_height, 0), 1)
+    x1 = min(max(poly[4] / page_width, 0), 1)
+    y1 = min(max(poly[5] / page_height, 0), 1)
+    item['bbox'] = [round(x0, 3), round(y0, 3), round(x1, 3), round(y1, 3)]
+    item.pop('poly', None)
+
+
+def _process_ocr_and_formulas(
+    images_pil_list,
+    results,
+    language,
+    inline_formula_enable,
+    _ocr_enable,
+    batch_radio: int = 1,
+):
+    """Run OCR and formula recognition"""
+
+    # Iterate over the results and send the text-block crops to OCR
+    # _ocr_enable decides whether OCR runs detection only or detection + recognition
+    # inline_formula_enable decides whether to combine MFD with OCR or to use OCR only
+
+    # Convert the PIL images to numpy arrays
+    np_images = [np.asarray(pil_image).copy() for pil_image in images_pil_list]
+
+    # Get the hybrid model instance
+    hybrid_model_singleton = HybridModelSingleton()
+    hybrid_pipeline_model = hybrid_model_singleton.get_model(
+        lang=language,
+        formula_enable=inline_formula_enable,
+    )
+
+    if inline_formula_enable:
+        # Before detecting and recognizing inline formulas, mask the image, table and interline-formula regions in the page images
+        np_images = mask_image_regions(np_images, results)
+        # Formula detection
+        images_mfd_res = hybrid_pipeline_model.mfd_model.batch_predict(np_images, batch_size=1, conf=0.5)
+        # Formula recognition
+        inline_formula_list = hybrid_pipeline_model.mfr_model.batch_predict(
+            images_mfd_res,
+            np_images,
+            batch_size=batch_radio*MFR_BASE_BATCH_SIZE,
+            interline_enable=True,
+        )
+    else:
+        inline_formula_list = [[] for _ in range(len(images_pil_list))]
+
+    mfd_res = []
+    for page_inline_formula_list in inline_formula_list:
+        page_mfd_res = []
+        for formula in page_inline_formula_list:
+            formula['category_id'] = 13
+            page_mfd_res.append({
+                "bbox": [int(formula['poly'][0]), int(formula['poly'][1]),
+                         int(formula['poly'][4]), int(formula['poly'][5])],
+            })
+        mfd_res.append(page_mfd_res)
+
+    # The VLM did not run OCR, so OCR detection is needed
+    ocr_res_list = ocr_det(
+        hybrid_pipeline_model,
+        np_images,
+        results,
+        mfd_res,
+        _ocr_enable,
+        batch_radio=batch_radio,
+    )
+
+    # Run OCR recognition if OCR is needed
+    if _ocr_enable:
+        need_ocr_list = []
+        img_crop_list = []
+        for page_ocr_res_list in ocr_res_list:
+            for ocr_res in page_ocr_res_list:
+                if 'np_img' in ocr_res:
+                    need_ocr_list.append(ocr_res)
+                    img_crop_list.append(ocr_res.pop('np_img'))
+        if len(img_crop_list) > 0:
+            # Process OCR
+            ocr_result_list = hybrid_pipeline_model.ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
+
+            # Verify we have matching counts
+            assert len(ocr_result_list) == len(need_ocr_list), f'ocr_result_list: {len(ocr_result_list)}, need_ocr_list: {len(need_ocr_list)}'
+
+            # Process OCR results for this language
+            for index, need_ocr_res in enumerate(need_ocr_list):
+                ocr_text, ocr_score = ocr_result_list[index]
+                need_ocr_res['text'] = ocr_text
+                need_ocr_res['score'] = float(f"{ocr_score:.3f}")
+                if ocr_score < OcrConfidence.min_confidence:
+                    need_ocr_res['category_id'] = 16
+                else:
+                    layout_res_bbox = [need_ocr_res['poly'][0], need_ocr_res['poly'][1],
+                                       need_ocr_res['poly'][4], need_ocr_res['poly'][5]]
+                    layout_res_width = layout_res_bbox[2] - layout_res_bbox[0]
+                    layout_res_height = layout_res_bbox[3] - layout_res_bbox[1]
+                    if (
+                        ocr_text in [
+                            '(204号', '(20', '(2', '(2号', '(20号', '号','(204',
+                            '(cid:)', '(ci:)', '(cd:1)', 'cd:)', 'c)', '(cd:)', 'c', 'id:)',
+                            ':)', '√:)', '√i:)', '−i:)', '−:' , 'i:)',
+                        ]
+                        and ocr_score < 0.8
+                        and layout_res_width < layout_res_height
+                    ):
+                        need_ocr_res['category_id'] = 16
+
+    return inline_formula_list, ocr_res_list, hybrid_pipeline_model
+
+
+def _normalize_bbox(
+    inline_formula_list,
+    ocr_res_list,
+    images_pil_list,
+):
+    """Normalize coordinates and produce the final result"""
+    for page_inline_formula_list, page_ocr_res_list, page_pil_image in zip(
+        inline_formula_list, ocr_res_list, images_pil_list
+    ):
+        if page_inline_formula_list or page_ocr_res_list:
+            page_width, page_height = page_pil_image.size
+            # Process the formula list
+            for formula in page_inline_formula_list:
+                normalize_poly_to_bbox(formula, page_width, page_height)
+            # Process the OCR result list
+            for ocr_res in page_ocr_res_list:
+                normalize_poly_to_bbox(ocr_res, page_width, page_height)
+
+
+def get_batch_ratio(device):
+    """
+    Get the batch ratio from the VRAM size or from an environment variable
+    """
+    # 1. Prefer the environment variable if it is set
+    """
+    For client/server split deployments, it is recommended to set the environment variable MINERU_HYBRID_BATCH_RATIO to specify the batch ratio.
+    The suggested values are listed below and already include some headroom; when several clients share a single GPU, you can additionally reserve one client's worth of VRAM as overall headroom for stability.
+    VRAM per client | MINERU_HYBRID_BATCH_RATIO
+    ------------------|------------------------
+    <= 6 GB | 8
+    <= 4.5 GB | 4
+    <= 3 GB | 2
+    <= 2.5 GB | 1
+    For example:
+    export MINERU_HYBRID_BATCH_RATIO=4
+    """
+    env_val = os.getenv("MINERU_HYBRID_BATCH_RATIO")
+    if env_val:
+        try:
+            batch_ratio = int(env_val)
+            logger.info(f"hybrid batch ratio (from env): {batch_ratio}")
+            return batch_ratio
+        except ValueError as e:
+            logger.warning(f"Invalid MINERU_HYBRID_BATCH_RATIO value: {env_val}, switching to auto mode. Error: {e}")
+
+    # 2. Otherwise infer it automatically from the VRAM size
+    """
+    Roughly estimate the batch ratio from the total VRAM; the VRAM already taken by inference frameworks such as vllm has to be excluded.
+    """
+    gpu_memory = get_vram(device)
+    if gpu_memory >= 32:
+        batch_ratio = 16
+    elif gpu_memory >= 16:
+        batch_ratio = 8
+    elif gpu_memory >= 12:
+        batch_ratio = 4
+    elif gpu_memory >= 8:
+        batch_ratio = 2
+    else:
+        batch_ratio = 1
+
+    logger.info(f"hybrid batch ratio (auto, vram={gpu_memory}GB): {batch_ratio}")
+    return batch_ratio
+
+
+def _should_enable_vlm_ocr(ocr_enable: bool, language: str, inline_formula_enable: bool) -> bool:
+    """Decide whether to enable VLM OCR"""
+    force_enable = os.getenv("MINERU_FORCE_VLM_OCR_ENABLE", "0").lower() in ("1", "true", "yes")
+    if force_enable:
+        return True
+
+    force_pipeline = os.getenv("MINERU_HYBRID_FORCE_PIPELINE_ENABLE", "0").lower() in ("1", "true", "yes")
+    return (
+        ocr_enable
+        and language in ["ch", "en"]
+        and inline_formula_enable
+        and not force_pipeline
+    )
+
+
+def doc_analyze(
+    pdf_bytes,
+    image_writer: DataWriter | None,
+    predictor: MinerUClient | None = None,
+    backend="transformers",
+    parse_method: str = 'auto',
+    language: str = 'ch',
+    inline_formula_enable: bool = True,
+    model_path: str | None = None,
+    server_url: str | None = None,
+    **kwargs,
+):
+    # Initialize the predictor
+    if predictor is None:
+        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
+
+    # Load the page images
+    load_images_start = time.time()
+    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
+    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
+    load_images_time = round(time.time() - load_images_start, 2)
+    logger.debug(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")
+
+    # Get the device
+    device = get_device()
+
+    # Determine the OCR configuration
+    _ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
+    _vlm_ocr_enable = _should_enable_vlm_ocr(_ocr_enable, language, inline_formula_enable)
+
+    infer_start = time.time()
+    # VLM extraction
+    if _vlm_ocr_enable:
+        results = predictor.batch_two_step_extract(images=images_pil_list)
+        hybrid_pipeline_model = None
+        inline_formula_list = [[] for _ in images_pil_list]
+        ocr_res_list = [[] for _ in images_pil_list]
+    else:
+        batch_ratio = get_batch_ratio(device)
+        results = predictor.batch_two_step_extract(
+            images=images_pil_list,
+            not_extract_list=not_extract_list
+        )
+        inline_formula_list, ocr_res_list, hybrid_pipeline_model = _process_ocr_and_formulas(
+            images_pil_list,
+            results,
+            language,
+            inline_formula_enable,
+            _ocr_enable,
+            batch_radio=batch_ratio,
+        )
+        _normalize_bbox(inline_formula_list, ocr_res_list, images_pil_list)
+    infer_time = round(time.time() - infer_start, 2)
+    logger.debug(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
+
+    # Produce the middle JSON
+    middle_json = result_to_middle_json(
+        results,
+        inline_formula_list,
+        ocr_res_list,
+        images_list,
+        pdf_doc,
+        image_writer,
+        _ocr_enable,
+        _vlm_ocr_enable,
+        hybrid_pipeline_model,
+    )
+
+    clean_memory(device)
+    return middle_json, results, _vlm_ocr_enable
+
+
+async def aio_doc_analyze(
+    pdf_bytes,
+    image_writer: DataWriter | None,
+    predictor: MinerUClient | None = None,
+    backend="transformers",
+    parse_method: str = 'auto',
+    language: str = 'ch',
+    inline_formula_enable: bool = True,
+    model_path: str | None = None,
+    server_url: str | None = None,
+    **kwargs,
+):
+    # Initialize the predictor
+    if predictor is None:
+        predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
+
+    # Load the page images
+    load_images_start = time.time()
+    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
+    images_pil_list = [image_dict["img_pil"] for image_dict in images_list]
+    load_images_time = round(time.time() - load_images_start, 2)
+    logger.debug(f"load images cost: {load_images_time}, speed: {round(len(images_pil_list)/load_images_time, 3)} images/s")
+
+    # Get the device
+    device = get_device()
+
+    # Determine the OCR configuration
+    _ocr_enable = ocr_classify(pdf_bytes, parse_method=parse_method)
+    _vlm_ocr_enable = _should_enable_vlm_ocr(_ocr_enable, language, inline_formula_enable)
+
+    infer_start = time.time()
+    # VLM extraction
+    if _vlm_ocr_enable:
+        results = await predictor.aio_batch_two_step_extract(images=images_pil_list)
+        hybrid_pipeline_model = None
+        inline_formula_list = [[] for _ in images_pil_list]
+        ocr_res_list = [[] for _ in images_pil_list]
+    else:
+        batch_ratio = get_batch_ratio(device)
+        results = await predictor.aio_batch_two_step_extract(
+            images=images_pil_list,
+            not_extract_list=not_extract_list
+        )
+        inline_formula_list, ocr_res_list, hybrid_pipeline_model = _process_ocr_and_formulas(
+            images_pil_list,
+            results,
+            language,
+            inline_formula_enable,
+            _ocr_enable,
+            batch_radio=batch_ratio,
+        )
+        _normalize_bbox(inline_formula_list, ocr_res_list, images_pil_list)
+    infer_time = round(time.time() - infer_start, 2)
+    logger.debug(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
+
+    # Produce the middle JSON
+    middle_json = result_to_middle_json(
+        results,
+        inline_formula_list,
+        ocr_res_list,
+        images_list,
+        pdf_doc,
+        image_writer,
+        _ocr_enable,
+        _vlm_ocr_enable,
+        hybrid_pipeline_model,
+    )
+
+    clean_memory(device)
+    return middle_json, results, _vlm_ocr_enable
+
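
For orientation, `doc_analyze` above is the entry point of the new hybrid backend added in 2.7.0. A minimal usage sketch follows; it is not taken from the package itself, and the input file name, the locally available model for the default "transformers" backend, and the printed summary are illustrative assumptions.

# Minimal usage sketch for the new hybrid backend (assumptions: "example.pdf"
# exists locally and a model for the default "transformers" backend is available).
from pathlib import Path

from mineru.backend.hybrid.hybrid_analyze import doc_analyze

pdf_bytes = Path("example.pdf").read_bytes()  # hypothetical input PDF

middle_json, vlm_results, vlm_ocr_used = doc_analyze(
    pdf_bytes,
    image_writer=None,           # skip persisting cropped images in this sketch
    backend="transformers",      # default backend in the new signature
    parse_method="auto",         # let ocr_classify() decide text vs. OCR handling
    language="ch",
    inline_formula_enable=True,
)

print(f"pages: {len(vlm_results)}, vlm ocr used: {vlm_ocr_used}")

A real DataWriter instance can be passed as image_writer instead of None if the cropped figure and table images should be written out alongside the middle JSON.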