openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,468 @@
1
+ import json
2
+ import os
3
+ import re
4
+ from typing import List, Dict, Any
5
+
6
+
7
def truncate_repeated_tail(s, threshold=20, keep=1):
    """Collapse a long run of a repeated pattern at the end of *s*.

    If the tail of *s* consists of one pattern (1 to 100 characters long)
    repeated strictly more than *threshold* times in a row, the whole run is
    replaced by *keep* copies of that pattern. Shorter patterns are tried
    first, so the first qualifying pattern length wins.

    Args:
        s: Input string. Empty/falsy input is returned unchanged.
        threshold: Repetition count that must be exceeded before the tail is
            truncated (default 20). A non-positive value no longer raises
            ZeroDivisionError. The trailing run of the last character is then
            always collapsed.
        keep: Number of copies of the repeated pattern to keep (default 1).

    Returns:
        The truncated string, or *s* unchanged if no pattern qualifies.
    """
    if not s:
        return s

    # A pattern of length L can only repeat more than `threshold` times if
    # L <= len(s) // threshold. Clamp the divisor so threshold <= 0 is safe.
    max_pattern_len = min(100, len(s) // max(threshold, 1))

    for pattern_len in range(1, max_pattern_len + 1):
        pattern = s[-pattern_len:]

        # Walk backwards from the end, counting consecutive copies of pattern.
        count = 0
        pos = len(s)
        while pos >= pattern_len and s[pos - pattern_len:pos] == pattern:
            count += 1
            pos -= pattern_len

        if count > threshold:
            # Keep the non-repeating prefix plus `keep` copies of the pattern.
            return s[:pos] + pattern * keep

    # No tail pattern repeats often enough; leave the string untouched.
    return s
54
+
55
+
56
def extract_table_from_html(html_string):
    """Return every ``<table>...</table>`` element of *html_string*.

    Tables are matched non-greedily across newlines and joined with ``'\\n'``.
    Attributes on each opening ``<table ...>`` tag are dropped, so every
    table starts with a bare ``<table>``. Any exception is printed and turned
    into a one-cell error table instead of propagating.
    """
    try:
        found = re.findall(r'<table.*?>.*?</table>', html_string, re.DOTALL)
        cleaned = []
        for tbl in found:
            cleaned.append(re.sub(r'<table[^>]*>', '<table>', tbl))
        return '\n'.join(cleaned)
    except Exception as err:
        print(f'extract_table_from_html error: {str(err)}')
        return f'<table><tr><td>Error extracting table: {str(err)}</td></tr></table>'
69
+
70
+
71
# Ordered (regex, replacement) clean-up rules applied to plain text by
# MarkdownConverter._handle_text.
rules = [
    # Remove the <|sn|> marker (first together with a preceding hyphen),
    # then the <|unk|> marker.
    (r'-<\|sn\|>', ''),
    (r'<\|sn\|>', ''),
    (r'<\|unk\|>', ''),
    # Remove the U+FFFF code point.
    (r'\uffff', ''),
    # Shorten long runs of underscores / dots to three characters.
    (r'_{4,}', '___'),
    (r'\.{4,}', '...'),
]
81
+
82
# Matches a LaTeX sizing command whose delimiter argument is wrapped in braces,
# e.g. "\big{(}" or "\Biggr{\}}". Group 1 is the command name, group 2 the
# (optionally backslash-escaped) delimiter. Used by fix_latex_brackets.
pattern = (
    r'\\(big|Big|bigg|Bigg|bigl|bigr|Bigl|Bigr|biggr|biggl|Biggl|Biggr)'
    r'\{(\\?[{}\[\]\(\)\|])\}'
)
84
+
85
+
86
def fix_latex_brackets(text: str) -> str:
    r"""Unwrap braced delimiters that follow LaTeX sizing commands.

    Every match of the module-level ``pattern`` regex is rewritten as the
    bare command followed by its delimiter: ``\big{(}`` becomes ``\big(``
    and ``\Bigl{\{}`` becomes ``\Bigl\{``. Everything else is untouched.
    """
    sizing_re = re.compile(pattern)
    return sizing_re.sub(r'\\\1\2', text)
88
+
89
+
90
class MarkdownConverter:
    """Convert structured recognition results to Markdown.

    ``convert`` takes a list of result dicts (reading their ``label`` and
    ``text_unirec`` keys), skips page furniture such as headers and footers,
    and dispatches every remaining item to a label-specific handler.
    """

    # Longer LaTeX commands that merely start with a ``replace_dict`` key.
    # ``replace_dict`` appends a space after several commands (presumably so a
    # command glued to a following letter gets split). Applied as a blind
    # substring replacement it also corrupted valid commands:
    #   \cdots -> "\cdot s", \leqslant -> "\leq slant", \bmod -> "\mathbf od",
    #   \pmod -> "\pm od", \multicolumn -> "\mu lticolumn".
    # A key is left alone when it is immediately followed by one of its
    # protected suffixes.
    _PROTECTED_SUFFIXES = {
        r'\bm': ('od',),
        r'\leq': ('slant',),
        r'\pm': ('od',),
        r'\mu': ('lti',),
        r'\cdot': ('s',),
    }

    def __init__(self):
        # Markdown heading prefix for each section label.
        self.heading_levels = {
            'sec_0': '#',
            'sec_1': '##',
            'sec_2': '###',
            'sec_3': '###',
            'sec_4': '###',
            'sec_5': '###',
        }

        # Labels that need special handling (not read by any method of this
        # class).
        self.special_labels = {
            'sec_0', 'sec_1', 'sec_2', 'sec_3', 'sec_4', 'sec_5', 'list',
            'equ', 'tab', 'fig'
        }

        # Substring replacements for LaTeX commands, applied in insertion
        # order. Raw strings make the backslashes explicit; the previous
        # duplicate '\pm' entry had the same value and is listed once.
        self.replace_dict = {
            r'\bm': r'\mathbf ',
            r'\eqno': r'\quad ',
            r'\quad': r'\quad ',
            r'\leq': r'\leq ',
            r'\pm': r'\pm ',
            r'\varmathbb': r'\mathbb ',
            r'\in fty': r'\infty',
            r'\mu': r'\mu ',
            r'\cdot': r'\cdot ',
            r'\langle': r'\langle ',
        }

    @staticmethod
    def _is_chinese(char: str) -> bool:
        """Return True if *char* is in the CJK Unified Ideographs block."""
        return '\u4e00' <= char <= '\u9fff'

    def _apply_replace_dict(self, text: str) -> str:
        """Apply ``self.replace_dict`` to *text* in insertion order.

        Keys that appear in ``_PROTECTED_SUFFIXES`` are not replaced where
        they are only the prefix of a longer LaTeX command. All other keys
        are replaced as plain substrings.
        """
        for key, value in self.replace_dict.items():
            suffixes = self._PROTECTED_SUFFIXES.get(key)
            if not suffixes:
                text = text.replace(key, value)
                continue
            guarded = (re.escape(key) + '(?!' +
                       '|'.join(map(re.escape, suffixes)) + ')')
            # A callable replacement keeps the backslashes in `value` literal.
            text = re.sub(guarded, lambda _m, v=value: v, text)
        return text

    def try_remove_newline(self, text: str) -> str:
        """Join hard-wrapped lines into flowing text.

        Hyphenated line breaks ("-\\n") are removed. Two adjacent non-empty
        lines are joined directly when both sides of the break are CJK
        characters, and with a single space otherwise. Empty lines are kept
        as paragraph breaks. On error the (possibly partly processed) text is
        returned.
        """
        try:
            text = text.strip().replace('-\n', '')
            lines = text.split('\n')
            processed_lines = []

            # Look at every line together with its successor; the last line
            # is appended separately below.
            for line, following in zip(lines, lines[1:]):
                current_line = line.strip()
                next_line = following.strip()
                if not current_line:
                    processed_lines.append('\n')
                elif not next_line:
                    processed_lines.append(current_line + '\n')
                elif (self._is_chinese(current_line[-1])
                      and self._is_chinese(next_line[0])):
                    processed_lines.append(current_line)
                else:
                    processed_lines.append(current_line + ' ')

            if lines and lines[-1].strip():
                processed_lines.append(lines[-1].strip())

            return ''.join(processed_lines)

        except Exception as e:
            print(f'try_remove_newline error: {str(e)}')
            return text

    def _handle_text(self, text: str) -> str:
        """Clean a plain-text block.

        Recognizer "no text" messages are mapped to an empty string. The
        module-level ``rules`` and the formula clean-up are applied, and
        ``$\\bullet$`` is turned into a bullet character. HTML table tags are
        stripped, keeping only their cell text. On error the text is returned
        as-is.
        """
        try:
            if not text:
                return ''
            if text in ['图中没有可识别的文本。', '图中无文本。', '图中没有文本。']:
                return ''
            for rule in rules:
                text = re.sub(rule[0], rule[1], text)
            text = self._process_formulas_in_text(text)
            # Raw string: in a plain literal '\b' is a backspace, so the old
            # '$\bullet$' never matched anything.
            text = text.replace(r'$\bullet$', '•')
            if '<table>' in text:
                text = re.sub(r'</?(table|tr|th|td|thead|tbody|tfoot)[^>]*>',
                              '',
                              text,
                              flags=re.IGNORECASE)
                text = re.sub(r'\n\s*\n+', '\n', text)
            return text
        except Exception as e:
            print(f'_handle_text error: {str(e)}')
            return text

    def _process_formulas_in_text(self, text: str) -> str:
        """Normalize inline math inside running text.

        ``\\upmu`` becomes ``\\mu``, the ``\\(`` / ``\\)`` delimiters become
        ``$``, and ``replace_dict`` is applied. On error the text is returned
        unchanged.
        """
        try:
            text = (text.replace(r'\upmu', r'\mu')
                    .replace(r'\(', '$')
                    .replace(r'\)', '$'))
            return self._apply_replace_dict(text)
        except Exception as e:
            print(f'_process_formulas_in_text error: {str(e)}')
            return text

    def _remove_newline_in_heading(self, text: str) -> str:
        """Flatten a multi-line heading to one line.

        Newlines are dropped when the text contains any CJK character and
        replaced by spaces otherwise.
        """
        try:
            if any(self._is_chinese(char) for char in text):
                return text.replace('\n', '')
            return text.replace('\n', ' ')
        except Exception as e:
            print(f'_remove_newline_in_heading error: {str(e)}')
            return text

    def _handle_heading(self, text: str, label: str) -> str:
        """Render *text* as a Markdown heading whose level depends on *label*."""
        try:
            level = self.heading_levels.get(label, '#')
            text = text.strip()
            text = self._remove_newline_in_heading(text)
            text = self._handle_text(text)
            return f'{level} {text}\n\n'
        except Exception as e:
            print(f'_handle_heading error: {str(e)}')
            return f'# Error processing heading: {text}\n\n'

    def _handle_list_item(self, text: str) -> str:
        """Render *text* as a single Markdown bullet item."""
        try:
            return f'- {text.strip()}\n'
        except Exception as e:
            print(f'_handle_list_item error: {str(e)}')
            return f'- Error processing list item: {text}\n'

    def _handle_figure(self, text: str, section_count: int) -> str:
        """Render figure content as a Markdown image.

        * ``figures/...`` paths are made relative to the Markdown directory
          (``../figures/...``).
        * Existing Markdown image links are passed through unchanged.
        * ``data:image/`` URIs, or text containing both ';' and ',', are
          used as the image target directly.
        * Anything else is assumed to be raw base64 PNG data.
        """
        try:
            if text.startswith('figures/'):
                relative_path = f'../{text}'
                return f'![Figure {section_count}]({relative_path})\n\n'

            if text.startswith('!['):
                return f'{text}\n\n'

            if text.startswith('data:image/'):
                return f'![Figure {section_count}]({text})\n\n'
            elif ';' in text and ',' in text:
                return f'![Figure {section_count}]({text})\n\n'
            else:
                img_format = 'png'
                data_uri = f'data:image/{img_format};base64,{text}'
                return f'![Figure {section_count}]({data_uri})\n\n'

        except Exception as e:
            print(f'_handle_figure error: {str(e)}')
            return f'*[Error processing figure: {str(e)}]*\n\n'

    def _handle_table(self, text: str) -> str:
        """Normalize an HTML table for inclusion in the Markdown output.

        The steps are:

        1. Keep only the ``<table>`` elements (via ``extract_table_from_html``).
        2. Repair span attributes glued to the tag name (``<tdcolspan=``).
        3. Strip the same special tokens as plain text.
        4. Drop span attributes that ended up on closing tags.
        5. Map ``\\(``/``\\)`` to ``$`` and ``\\[``/``\\]`` to ``$$``.

        The result stays HTML.
        """
        try:
            table_content = extract_table_from_html(text)
            table_content = (table_content
                             .replace('<tdcolspan=', '<td colspan=')
                             .replace('<tdrowspan=', '<td rowspan=')
                             .replace('"colspan=', '" colspan='))
            for junk in (r'<\|sn\|>', r'<\|unk\|>', r'\uffff'):
                table_content = re.sub(junk, '', table_content)
            table_content = re.sub(r'_{4,}', '___', table_content)
            table_content = re.sub(r'\.{4,}', '...', table_content)

            # Span attributes belong on opening tags only. A closing
            # </td ...> / </th ...> carrying one is rewritten to a bare,
            # lower-case closing tag.
            table_content = re.sub(
                r'</(td|th)\s+(?:colspan|rowspan)="[^"]*"\s*>',
                lambda m: f'</{m.group(1).lower()}>',
                table_content,
                flags=re.IGNORECASE)

            table_content = (table_content
                             .replace('\\(', '$').replace('\\)', '$')
                             .replace('\\[', '$$').replace('\\]', '$$'))
            return table_content + '\n\n\n'

        except Exception as e:
            print(f'_handle_table error: {str(e)}')
            return f'*[Error processing table: {str(e)}]*\n\n'

    def _strip_display_delimiters(self, text: str) -> str:
        r"""Remove math delimiters and special tokens from a formula.

        The steps are:

        1. Rewrite ``\upmu`` as ``\mu``.
        2. Drop ``(n)`` equation numbers that follow ``\]``.
        3. Remove the ``<|sn|>`` / ``<|unk|>`` / U+FFFF tokens.
        4. Shorten long underscore runs.
        5. Merge consecutive display blocks (``\] ... \[``) with a LaTeX line
           break ``\\``. Previously this was a ``str.replace`` of the
           regex-looking ``'\]\n*\['``, which never matched, so separate
           equations were concatenated without any separator.
        6. Remove every remaining ``\[ \] \( \)`` delimiter.
        """
        text = text.replace(r'\upmu', r'\mu')
        text = re.sub(r'\\] \(\d+\)\n\n', r'\\]', text)
        for junk in (r'<\|sn\|>', r'<\|unk\|>', r'\uffff'):
            text = re.sub(junk, '', text)
        text = re.sub(r'_{4,}', '___', text)
        text = re.sub(r'\\\]\s*\\\[', r'\\\\', text)
        for token in ('\n\n\\[', '\\]\n\n', '\\[\n', '\n\\]', '\\]', '\\[',
                      '\\( ', ' \\)', '\\(', '\\)'):
            text = text.replace(token, '')
        return text

    def _handle_formula(self, text: str) -> str:
        """Render a display formula as a ``$$ ... $$`` Markdown math block."""
        try:
            text = self._strip_display_delimiters(text)
            text = text.strip('$').rstrip('\\ ').replace(r'\upmu', r'\mu')
            text = self._apply_replace_dict(text)
            # Remaining newlines become LaTeX line breaks inside the block.
            processed_text = ('$$' + text + '$$').replace('\n', '\\\\\n')
            processed_text = fix_latex_brackets(processed_text)
            return f'{processed_text}\n\n'
        except Exception as e:
            print(f'_handle_formula error: {str(e)}')
            return f'*[Error processing formula: {str(e)}]*\n\n'

    def convert(self, recognition_results: List[Dict[str, Any]]) -> str:
        """Convert recognition results to one Markdown string.

        Args:
            recognition_results: Sequence of dicts. Each one is read via
                ``label`` and ``text_unirec``.

        Returns:
            The concatenated Markdown. A failing item contributes an error
            placeholder. A failure outside the per-item loop returns an error
            message instead of raising.
        """
        try:
            markdown_content = []

            for section_count, result in enumerate(recognition_results):
                try:
                    label = result.get('label', '')
                    # `or ''` also covers an explicit None value.
                    text = (result.get('text_unirec') or '').strip()

                    # Skip empty items and page furniture. 'number' items are
                    # always skipped, regardless of their position.
                    if not text or label in [
                            'header', 'header_image', 'footer_image', 'footer',
                            'aside_text', 'inline_formula', 'number'
                    ]:
                        continue

                    text = truncate_repeated_tail(text)
                    if label == 'doc_title':
                        label = 'sec_0'
                    elif label == 'paragraph_title':
                        label = 'sec_1'

                    if label in {
                            'sec_0', 'sec_1', 'sec_2', 'sec_3', 'sec_4',
                            'sec_5'
                    }:
                        markdown_content.append(
                            self._handle_heading(text, label))
                    elif label in ['image', 'chart', 'seal']:
                        markdown_content.append(
                            self._handle_figure(text, section_count))
                    elif label == 'table':
                        markdown_content.append(self._handle_table(text))
                    elif label == 'display_formula':
                        markdown_content.append(self._handle_formula(text))
                    elif label == 'list':
                        markdown_content.append(self._handle_list_item(text))
                    elif label == 'code':
                        markdown_content.append(f'```bash\n{text}\n```\n\n')
                    else:
                        # Regular text (paragraphs, captions, etc.).
                        processed_text = self._handle_text(text)
                        markdown_content.append(f'{processed_text}\n\n')

                except Exception as e:
                    print(f'Error processing item {section_count}: {str(e)}')
                    markdown_content.append('*[Error processing content]*\n\n')

            return ''.join(markdown_content)

        except Exception as e:
            print(f'convert error: {str(e)}')
            return f'Error generating markdown content: {str(e)}'
449
+
450
+
451
if __name__ == '__main__':
    # Batch-convert saved recognition results (<save_res_path>/<stem>_res.json,
    # one per image in img_path) into <save_res_path>/markdown_results/<stem>.md.
    markdown_converter = MarkdownConverter()
    img_path = './OmniDocBench/images'
    save_res_path = './rec_results'
    md_save_path = os.path.join(save_res_path, 'markdown_results')
    os.makedirs(md_save_path, exist_ok=True)
    for img_name in os.listdir(img_path):
        # splitext handles extensions of any length (.jpeg, .webp, ...).
        # Slicing off a fixed 4 characters only worked for 3-letter ones.
        stem = os.path.splitext(img_name)[0]
        file_path = os.path.join(save_res_path, stem + '_res.json')
        if not os.path.exists(file_path):
            continue
        # Use explicit UTF-8: the content may contain CJK text, and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(file_path, 'r', encoding='utf-8') as f:
            json_data = json.load(f)
        block_list = json_data['boxes']
        markdown_content = markdown_converter.convert(block_list)
        with open(os.path.join(md_save_path, stem + '.md'), 'w',
                  encoding='utf-8') as f:
            f.write(markdown_content)
@@ -4,7 +4,6 @@ import torch
4
4
 
5
5
  from tools.utils.logging import get_logger
6
6
 
7
-
8
7
  def save_ckpt(
9
8
  model,
10
9
  cfg,
@@ -79,9 +78,22 @@ def load_ckpt(model, cfg, optimizer=None, lr_scheduler=None, logger=None):
79
78
 
80
79
 
81
80
  def load_pretrained_params(model, pretrained_model, logger):
82
- checkpoint = torch.load(pretrained_model, map_location=torch.device("cpu"))
83
- model.load_state_dict(checkpoint["state_dict"], strict=False)
84
- for name in model.state_dict().keys():
85
- if name not in checkpoint["state_dict"]:
81
+ if pretrained_model.endswith(".safetensors"):
82
+ from safetensors.torch import load_file
83
+ logger.info(f"Loading weights from safetensors: {pretrained_model}")
84
+ checkpoint = load_file(pretrained_model)
85
+ else:
86
+ logger.info(f"Loading weights using torch.load: {pretrained_model}")
87
+ checkpoint = torch.load(pretrained_model, map_location=torch.device("cpu"))
88
+
89
+ if "state_dict" in checkpoint:
90
+ state_dict = checkpoint["state_dict"]
91
+ else:
92
+ state_dict = checkpoint
93
+
94
+ model.load_state_dict(state_dict, strict=False)
95
+ model_keys = model.state_dict().keys()
96
+ for name in model_keys:
97
+ if name not in state_dict:
86
98
  logger.info(f"{name} is not in pretrained model")
87
99