openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,468 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
from typing import List, Dict, Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def truncate_repeated_tail(s, threshold=20, keep=1):
    """Collapse a long run of a repeated pattern at the end of a string.

    Candidate pattern lengths are tried from 1 upward. The candidate pattern
    is always the current suffix of ``s``. For the first (shortest) one that
    repeats back-to-back more than ``threshold`` times at the tail, the whole
    run is replaced by ``keep`` copies of that pattern.

    Args:
        s: Input string. Falsy values ('' / None) are returned unchanged.
        threshold: Repeat-count threshold. A tail run is truncated only when
            it repeats strictly more than this many times. Default 20.
            Values below 1 are treated as 1 when bounding the pattern search,
            which avoids a division by zero.
        keep: Number of pattern copies kept after truncation. Default 1.

    Returns:
        The truncated string, or ``s`` unchanged when no tail pattern repeats
        more than ``threshold`` times.
    """
    if not s:
        return s

    # A pattern repeated more than `threshold` times must fit in `s` that many
    # times, so its length is bounded by len(s) // threshold. The bound is
    # capped at 100 to keep the scan cheap. Because this bound is <= len(s),
    # every candidate slice below is well-defined.
    max_pattern_len = min(100, len(s) // max(threshold, 1))

    for pattern_len in range(1, max_pattern_len + 1):
        # Candidate repeating unit: the last `pattern_len` characters.
        pattern = s[-pattern_len:]

        # Walk backward from the end, counting consecutive copies of `pattern`.
        count = 0
        pos = len(s)
        while pos >= pattern_len and s[pos - pattern_len:pos] == pattern:
            count += 1
            pos -= pattern_len

        if count > threshold:
            # Keep the non-repeating prefix plus `keep` copies of the pattern.
            return s[:pos] + pattern * keep

    # No tail pattern repeated more than `threshold` times.
    return s
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def extract_table_from_html(html_string):
    """Return every ``<table>...</table>`` fragment found in ``html_string``.

    The tag match is case-sensitive and spans newlines. Every opening
    ``<table ...>`` tag in a fragment is reduced to a bare ``<table>``. The
    fragments are joined with newlines, so an input with no tables yields ''.

    On any error (for example a non-string input), the error is printed and a
    one-cell HTML table containing the message is returned instead of raising.
    """
    try:
        fragments = re.findall(r'<table.*?>.*?</table>', html_string, re.DOTALL)
        return '\n'.join(
            re.sub(r'<table[^>]*>', '<table>', fragment)
            for fragment in fragments)
    except Exception as exc:
        print(f'extract_table_from_html error: {str(exc)}')
        return f'<table><tr><td>Error extracting table: {str(exc)}</td></tr></table>'
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
# Cleanup rules applied, in order, to plain text blocks. Each entry is a
# (regex, replacement) pair passed to re.sub.
rules = [
    (r'-<\|sn\|>', ''),   # literal '-<|sn|>' marker
    (r'<\|sn\|>', ''),    # literal '<|sn|>' marker
    (r'<\|unk\|>', ''),   # literal '<|unk|>' marker
    (r'\uffff', ''),      # U+FFFF character
    (r'_{4,}', '___'),    # collapse runs of 4+ underscores
    (r'\.{4,}', '...'),   # collapse runs of 4+ dots
]

# Matches a sizing command applied to a single braced delimiter, e.g.
# '\big{(}' or '\Bigr{\}}'. Group 1 is the command name, group 2 the delimiter.
pattern = r'\\(big|Big|bigg|Bigg|bigl|bigr|Bigl|Bigr|biggr|biggl|Biggl|Biggr)\{(\\?[{}\[\]\(\)\|])\}'
_BIG_DELIMITER_RE = re.compile(pattern)


def fix_latex_brackets(text: str) -> str:
    r"""Remove the braces around delimiters passed to sizing commands.

    '\big{(}' becomes '\big(' and '\Bigr{\}}' becomes '\Bigr\}'. Everything
    else in ``text`` is left untouched.
    """
    return _BIG_DELIMITER_RE.sub(r'\\\1\2', text)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class MarkdownConverter:
    """Convert structured recognition results to Markdown format.

    ``convert`` walks a list of block dicts, reading their ``label`` and
    ``text_unirec`` keys, and concatenates one Markdown fragment per block.
    """

    def __init__(self):
        # Markdown heading prefix per section label. sec_2..sec_5 all map to
        # level-3 headings.
        self.heading_levels = {
            'sec_0': '#',
            'sec_1': '##',
            'sec_2': '###',
            'sec_3': '###',
            'sec_4': '###',
            'sec_5': '###',
        }

        # Labels considered "special". Not referenced by this class's own
        # methods; kept as a public attribute.
        self.special_labels = {
            'sec_0', 'sec_1', 'sec_2', 'sec_3', 'sec_4', 'sec_5', 'list',
            'equ', 'tab', 'fig'
        }

        # Plain-substring replacements applied to formula text, in this order.
        # Raw strings give the same values as the previous escaped literals
        # without invalid-escape warnings.
        # NOTE(review): these are substring replacements, so they also hit
        # longer commands sharing the prefix (e.g. '\cdots' -> '\cdot s',
        # '\pmod' -> '\pm od', '\bmod' -> '\mathbf od'). Confirm this is
        # acceptable for the model's output.
        self.replace_dict = {
            r'\bm': r'\mathbf ',
            r'\eqno': r'\quad ',
            r'\quad': r'\quad ',
            r'\leq': r'\leq ',
            r'\pm': r'\pm ',
            r'\varmathbb': r'\mathbb ',
            r'\in fty': r'\infty',
            r'\mu': r'\mu ',
            r'\cdot': r'\cdot ',
            r'\langle': r'\langle ',
        }

    @staticmethod
    def _is_chinese(char):
        """Return True if ``char`` lies in the range U+4E00..U+9FFF."""
        return '\u4e00' <= char <= '\u9fff'

    def try_remove_newline(self, text: str) -> str:
        """Join wrapped lines of a paragraph while keeping paragraph breaks.

        Hyphenated breaks ('-' followed by a newline) are removed outright.
        Two consecutive non-empty lines are joined directly when the
        characters on both sides of the break are Chinese; otherwise they are
        joined with a single space. Empty lines are kept as paragraph
        separators. On error the error is printed and the (possibly
        partially processed) text is returned.
        """
        try:
            text = text.strip()
            text = text.replace('-\n', '')

            lines = text.split('\n')
            processed_lines = []

            # Every line except the last picks its separator by looking at
            # the line that follows it.
            for current_raw, next_raw in zip(lines, lines[1:]):
                current_line = current_raw.strip()
                next_line = next_raw.strip()
                if not current_line:
                    # Empty line: keep it as a paragraph separator.
                    processed_lines.append('\n')
                elif not next_line:
                    # Paragraph ends here.
                    processed_lines.append(current_line + '\n')
                elif (self._is_chinese(current_line[-1])
                      and self._is_chinese(next_line[0])):
                    # Chinese text is joined without a space.
                    processed_lines.append(current_line)
                else:
                    processed_lines.append(current_line + ' ')

            if lines and lines[-1].strip():
                processed_lines.append(lines[-1].strip())

            return ''.join(processed_lines)

        except Exception as e:
            print(f'try_remove_newline error: {str(e)}')
            return text  # Return original text on error

    def _handle_text(self, text: str) -> str:
        """Clean a plain text block for Markdown output.

        Returns '' for empty input and for the model's placeholder "no text"
        answers. Otherwise it:

        - applies the module-level cleanup ``rules``;
        - normalizes inline formulas via ``_process_formulas_in_text``;
        - maps the LaTeX bullet ``$\\bullet$`` to '•';
        - strips any HTML table tags, collapsing the blank lines they leave.
        """
        try:
            if not text:
                return ''
            # Placeholder answers meaning "the image contains no (recognizable) text".
            if text in ['图中没有可识别的文本。', '图中无文本。', '图中没有文本。']:
                return ''
            for rule in rules:
                text = re.sub(rule[0], rule[1], text)
            # Process formulas in text before handling other text processing
            text = self._process_formulas_in_text(text)
            # Raw string: in a plain literal '\b' is a backspace character, so
            # the un-prefixed pattern could never match LaTeX '$\bullet$'.
            text = text.replace(r'$\bullet$', '•')
            # Remove HTML table tags that leaked into a plain text block.
            if '<table>' in text:
                text = re.sub(r'</?(table|tr|th|td|thead|tbody|tfoot)[^>]*>',
                              '',
                              text,
                              flags=re.IGNORECASE)
                text = re.sub(r'\n\s*\n+', '\n', text)
            return text
        except Exception as e:
            print(f'_handle_text error: {str(e)}')
            return text  # Return original text on error

    def _process_formulas_in_text(self, text: str) -> str:
        """Normalize inline LaTeX inside running text.

        Replaces '\\upmu' with '\\mu', converts the inline delimiters '\\('
        and '\\)' to '$', then applies every ``replace_dict`` substitution in
        order. On error the error is printed and the text processed so far is
        returned.
        """
        try:
            text = (text.replace(r'\upmu', r'\mu')
                    .replace(r'\(', '$')
                    .replace(r'\)', '$'))
            for key, value in self.replace_dict.items():
                text = text.replace(key, value)
            return text

        except Exception as e:
            print(f'_process_formulas_in_text error: {str(e)}')
            return text  # Return original text on error

    def _remove_newline_in_heading(self, text: str) -> str:
        """Flatten a multi-line heading onto one line.

        If the text contains any Chinese character, newlines are removed.
        Otherwise each newline becomes a space.
        """
        try:
            if any(self._is_chinese(char) for char in text):
                return text.replace('\n', '')
            return text.replace('\n', ' ')

        except Exception as e:
            print(f'_remove_newline_in_heading error: {str(e)}')
            return text

    def _handle_heading(self, text: str, label: str) -> str:
        """Render a section heading as '<level> <text>' followed by a blank line.

        The level prefix comes from ``heading_levels`` (default '#'). The text
        is flattened to one line and cleaned with ``_handle_text``.
        """
        try:
            level = self.heading_levels.get(label, '#')
            text = text.strip()
            text = self._remove_newline_in_heading(text)
            text = self._handle_text(text)
            return f'{level} {text}\n\n'

        except Exception as e:
            print(f'_handle_heading error: {str(e)}')
            return f'# Error processing heading: {text}\n\n'

    def _handle_list_item(self, text: str) -> str:
        """Render ``text`` as a single '- ' Markdown list item."""
        try:
            return f'- {text.strip()}\n'
        except Exception as e:
            print(f'_handle_list_item error: {str(e)}')
            return f'- Error processing list item: {text}\n'

    def _handle_figure(self, text: str, section_count: int) -> str:
        """Render figure content as a Markdown image.

        Accepted forms of ``text``:

        - a path starting with 'figures/', linked relative to the parent
          directory;
        - an existing Markdown image link, passed through unchanged;
        - a data URI (or text containing both ';' and ','), used as the URL;
        - anything else, treated as raw base64 PNG data.

        ``section_count`` is used in the image alt text.
        """
        try:
            # Convert to relative path from markdown directory to figures directory
            if text.startswith('figures/'):
                relative_path = f'../{text}'
                return f'![Figure {section_count}]({relative_path})\n\n'

            # Already a Markdown image link: return it unchanged.
            if text.startswith('!['):
                return f'{text}\n\n'

            # Data-URI-like content is used directly as the image URL.
            if text.startswith('data:image/') or (';' in text and ',' in text):
                return f'![Figure {section_count}]({text})\n\n'

            # Assume it's raw base64, convert to data URI
            img_format = 'png'
            data_uri = f'data:image/{img_format};base64,{text}'
            return f'![Figure {section_count}]({data_uri})\n\n'

        except Exception as e:
            print(f'_handle_figure error: {str(e)}')
            return f'*[Error processing figure: {str(e)}]*\n\n'

    def _handle_table(self, text: str) -> str:
        """Extract HTML tables from ``text`` and clean them up.

        The output is still HTML. This method:

        - fixes attributes glued to tag names (e.g. '<tdcolspan=' becomes
          '<td colspan=');
        - removes special tokens and collapses long '_' / '.' runs;
        - drops span attributes emitted on closing tags;
        - converts LaTeX delimiters: '\\(' and '\\)' become '$', and '\\['
          and '\\]' become '$$'.
        """
        try:
            table_content = extract_table_from_html(text)
            table_content = table_content.replace('<tdcolspan=', '<td colspan=')
            table_content = table_content.replace('<tdrowspan=', '<td rowspan=')
            table_content = table_content.replace('"colspan=', '" colspan=')
            table_content = re.sub(r'<\|sn\|>', '', table_content)
            table_content = re.sub(r'<\|unk\|>', '', table_content)
            table_content = re.sub(r'\uffff', '', table_content)
            table_content = re.sub(r'_{4,}', '___', table_content)
            table_content = re.sub(r'\.{4,}', '...', table_content)

            # Span attributes belong on opening tags only. Drop any emitted on
            # closing tags, e.g. '</td colspan="2">' becomes '</td>'.
            for tag, attr in (('td', 'colspan'), ('td', 'rowspan'),
                              ('th', 'rowspan'), ('th', 'colspan')):
                table_content = re.sub(rf'</{tag}\s+{attr}="[^"]*"\s*>',
                                       f'</{tag}>',
                                       table_content,
                                       flags=re.IGNORECASE)

            table_content = table_content.replace(r'\(', '$').replace(r'\)', '$')
            table_content = table_content.replace(r'\[', '$$').replace(r'\]', '$$')
            return f'{table_content}\n\n\n'

        except Exception as e:
            print(f'_handle_table error: {str(e)}')
            return f'*[Error processing table: {str(e)}]*\n\n'

    def _clean_formula_body(self, text: str) -> str:
        r"""Strip delimiters and noise from display-formula text.

        Steps, in order:

        - drops equation numbers written as '\] (N)' plus a blank line;
        - removes special tokens and collapses long '_' runs;
        - joins consecutive display blocks ('\]', optional newlines, '\[')
          with a LaTeX line break '\\';
        - removes the remaining '\[', '\]', '\(', '\)' delimiters, surrounding
          '$', and trailing backslashes/spaces;
        - applies ``replace_dict``.
        """
        text = text.replace(r'\upmu', r'\mu')
        result = re.sub(r'\\] \(\d+\)\n\n', r'\\]', text)
        result = re.sub(r'<\|sn\|>', '', result)
        result = re.sub(r'<\|unk\|>', '', result)
        result = re.sub(r'\uffff', '', result)
        result = re.sub(r'_{4,}', '___', result)
        # This must be a regex. The previous str.replace('\]\n*\[', ...)
        # searched for a literal '*' and never matched, so separate display
        # formulas were fused together without any separator.
        result = re.sub(r'\\\]\n*\\\[', lambda _m: '\\\\', result)
        result = result.replace('\n\n\\[', '')
        result = result.replace('\\]\n\n', '')
        result = result.replace('\\[\n', '')
        result = result.replace('\n\\]', '')
        result = result.replace('\\]', '')
        result = result.replace('\\[', '')
        result = result.replace('\\( ', '')
        result = result.replace(' \\)', '')
        result = result.replace('\\(', '')
        text = result.replace('\\)', '')
        text = text.strip('$').rstrip('\\ ').replace(r'\upmu', r'\mu')
        for key, value in self.replace_dict.items():
            text = text.replace(key, value)
        return text

    def _handle_formula(self, text: str) -> str:
        """Render a display formula as a '$$...$$' block followed by a blank line.

        The body is cleaned by ``_clean_formula_body``. Remaining newlines
        become LaTeX row breaks, and braced sizing delimiters are fixed with
        ``fix_latex_brackets``.
        """
        try:
            text = self._clean_formula_body(text)
            processed_text = '$$' + text + '$$'
            # Turn remaining line breaks into LaTeX '\\' row breaks.
            processed_text = processed_text.replace('\n', '\\\\\n')
            processed_text = fix_latex_brackets(processed_text)
            return f'{processed_text}\n\n'

        except Exception as e:
            print(f'_handle_formula error: {str(e)}')
            return f'*[Error processing formula: {str(e)}]*\n\n'

    def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
        """Convert recognition results to a single Markdown string.

        Each item's 'label' selects the renderer. 'doc_title' and
        'paragraph_title' map to sec_0 and sec_1 headings. Page furniture and
        inline labels are skipped, as are items whose 'text_unirec' is empty
        or missing. A failing item is replaced by an error placeholder, so
        the remaining items are still converted.
        """
        try:
            markdown_content = []

            for section_count, result in enumerate(recognition_results):
                try:
                    label = result.get('label', '')
                    # `or ''` also covers an explicit None value.
                    text = (result.get('text_unirec') or '').strip()

                    # Skip empty text
                    if not text:
                        continue
                    # These labels are never emitted (including every 'number'
                    # block, so no separate page-number check is needed).
                    if label in [
                            'header', 'header_image', 'footer_image', 'footer',
                            'aside_text', 'inline_formula', 'number'
                    ]:
                        continue

                    # Drop degenerate repeated output at the end of the text.
                    text = truncate_repeated_tail(text)
                    if label == 'doc_title':
                        label = 'sec_0'
                    elif label == 'paragraph_title':
                        label = 'sec_1'

                    # Handle different content types
                    if label in {
                            'sec_0', 'sec_1', 'sec_2', 'sec_3', 'sec_4',
                            'sec_5'
                    }:
                        markdown_content.append(
                            self._handle_heading(text, label))
                    elif label in ['image', 'chart', 'seal']:
                        markdown_content.append(
                            self._handle_figure(text, section_count))
                    elif label == 'table':
                        markdown_content.append(self._handle_table(text))
                    elif label == 'display_formula':
                        markdown_content.append(self._handle_formula(text))
                    elif label == 'list':
                        markdown_content.append(self._handle_list_item(text))
                    elif label == 'code':
                        markdown_content.append(f'```bash\n{text}\n```\n\n')
                    else:
                        # Handle regular text (paragraphs, etc.)
                        processed_text = self._handle_text(text)
                        markdown_content.append(f'{processed_text}\n\n')

                except Exception as e:
                    print(f'Error processing item {section_count}: {str(e)}')
                    # Add a placeholder for the failed item
                    markdown_content.append('*[Error processing content]*\n\n')

            return ''.join(markdown_content)

        except Exception as e:
            print(f'convert error: {str(e)}')
            return f'Error generating markdown content: {str(e)}'
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
if __name__ == '__main__':
    # Batch conversion: for every image in img_path, load the saved
    # recognition result '<stem>_res.json' from save_res_path (if present).
    # Its 'boxes' list is converted to Markdown and written to
    # '<save_res_path>/markdown_results/<stem>.md'.
    markdown_converter = MarkdownConverter()
    img_path = './OmniDocBench/images'
    save_res_path = './rec_results'
    img_path_list = os.listdir(img_path)
    md_save_path = f'{save_res_path}/markdown_results'
    os.makedirs(md_save_path, exist_ok=True)
    for img_name in img_path_list:
        # splitext handles any extension length (e.g. '.jpeg') or none at
        # all, unlike slicing off the last four characters.
        stem = os.path.splitext(img_name)[0]
        file_path = os.path.join(save_res_path, stem + '_res.json')
        if not os.path.exists(file_path):
            continue
        # Explicit UTF-8: results and Markdown may contain Chinese text.
        with open(file_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        block_list = json_data['boxes']
        markdown_content = markdown_converter.convert(block_list)
        with open(os.path.join(md_save_path, stem + '.md'), 'w',
                  encoding='utf-8') as f:
            f.write(markdown_content)
|
openocr/tools/utils/ckpt.py
CHANGED
|
@@ -4,7 +4,6 @@ import torch
|
|
|
4
4
|
|
|
5
5
|
from tools.utils.logging import get_logger
|
|
6
6
|
|
|
7
|
-
|
|
8
7
|
def save_ckpt(
|
|
9
8
|
model,
|
|
10
9
|
cfg,
|
|
@@ -79,9 +78,22 @@ def load_ckpt(model, cfg, optimizer=None, lr_scheduler=None, logger=None):
|
|
|
79
78
|
|
|
80
79
|
|
|
81
80
|
def load_pretrained_params(model, pretrained_model, logger):
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
81
|
+
if pretrained_model.endswith(".safetensors"):
|
|
82
|
+
from safetensors.torch import load_file
|
|
83
|
+
logger.info(f"Loading weights from safetensors: {pretrained_model}")
|
|
84
|
+
checkpoint = load_file(pretrained_model)
|
|
85
|
+
else:
|
|
86
|
+
logger.info(f"Loading weights using torch.load: {pretrained_model}")
|
|
87
|
+
checkpoint = torch.load(pretrained_model, map_location=torch.device("cpu"))
|
|
88
|
+
|
|
89
|
+
if "state_dict" in checkpoint:
|
|
90
|
+
state_dict = checkpoint["state_dict"]
|
|
91
|
+
else:
|
|
92
|
+
state_dict = checkpoint
|
|
93
|
+
|
|
94
|
+
model.load_state_dict(state_dict, strict=False)
|
|
95
|
+
model_keys = model.state_dict().keys()
|
|
96
|
+
for name in model_keys:
|
|
97
|
+
if name not in state_dict:
|
|
86
98
|
logger.info(f"{name} is not in pretrained model")
|
|
87
99
|
|