yomitoku 0.4.1__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
Files changed (35)
  1. yomitoku/base.py +1 -1
  2. yomitoku/cli/main.py +219 -27
  3. yomitoku/configs/__init__.py +2 -0
  4. yomitoku/configs/cfg_text_detector_dbnet.py +1 -1
  5. yomitoku/configs/cfg_text_recognizer_parseq_small.py +51 -0
  6. yomitoku/data/functions.py +48 -23
  7. yomitoku/document_analyzer.py +243 -41
  8. yomitoku/export/__init__.py +18 -5
  9. yomitoku/export/export_csv.py +71 -2
  10. yomitoku/export/export_html.py +46 -12
  11. yomitoku/export/export_json.py +66 -3
  12. yomitoku/export/export_markdown.py +42 -6
  13. yomitoku/layout_analyzer.py +2 -9
  14. yomitoku/layout_parser.py +58 -4
  15. yomitoku/models/dbnet_plus.py +13 -39
  16. yomitoku/models/layers/activate.py +13 -0
  17. yomitoku/models/layers/rtdetr_backbone.py +18 -17
  18. yomitoku/models/layers/rtdetr_hybrid_encoder.py +19 -20
  19. yomitoku/models/layers/rtdetrv2_decoder.py +14 -1
  20. yomitoku/models/parseq.py +15 -22
  21. yomitoku/ocr.py +24 -27
  22. yomitoku/onnx/.gitkeep +0 -0
  23. yomitoku/postprocessor/dbnet_postporcessor.py +15 -14
  24. yomitoku/postprocessor/parseq_tokenizer.py +1 -3
  25. yomitoku/postprocessor/rtdetr_postprocessor.py +14 -1
  26. yomitoku/table_structure_recognizer.py +82 -9
  27. yomitoku/text_detector.py +57 -7
  28. yomitoku/text_recognizer.py +84 -16
  29. yomitoku/utils/misc.py +21 -14
  30. yomitoku/utils/visualizer.py +15 -8
  31. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/METADATA +34 -41
  32. yomitoku-0.7.4.dist-info/RECORD +54 -0
  33. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/WHEEL +1 -1
  34. yomitoku-0.4.1.dist-info/RECORD +0 -52
  35. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/entry_points.txt +0 -0
@@ -2,17 +2,20 @@ import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Union
 
+import numpy as np
 from pydantic import conlist
 
+from yomitoku.text_detector import TextDetector
+from yomitoku.text_recognizer import TextRecognizer
+
 from .base import BaseSchema
 from .export import export_csv, export_html, export_markdown
 from .layout_analyzer import LayoutAnalyzer
-from .ocr import OCR, WordPrediction
-from .table_structure_recognizer import TableStructureRecognizerSchema
-from .utils.misc import is_contained, quad_to_xyxy
+from .ocr import OCRSchema, WordPrediction, ocr_aggregate
 from .reading_order import prediction_reading_order
-
-from .utils.visualizer import reading_order_visualizer
+from .table_structure_recognizer import TableStructureRecognizerSchema
+from .utils.misc import calc_overlap_ratio, is_contained, quad_to_xyxy
+from .utils.visualizer import det_visualizer, reading_order_visualizer
 
 
 class ParagraphSchema(BaseSchema):
@@ -38,13 +41,13 @@ class DocumentAnalyzerSchema(BaseSchema):
     figures: List[FigureSchema]
 
     def to_html(self, out_path: str, **kwargs):
-        export_html(self, out_path, **kwargs)
+        return export_html(self, out_path, **kwargs)
 
     def to_markdown(self, out_path: str, **kwargs):
-        export_markdown(self, out_path, **kwargs)
+        return export_markdown(self, out_path, **kwargs)
 
     def to_csv(self, out_path: str, **kwargs):
-        export_csv(self, out_path, **kwargs)
+        return export_csv(self, out_path, **kwargs)
 
 
 def combine_flags(flag1, flag2):
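
Note on the hunk above (evidently yomitoku/document_analyzer.py): to_html, to_markdown, and to_csv now propagate the exporters' return values instead of returning None. A minimal sketch, assuming results is a DocumentAnalyzerSchema from a prior DocumentAnalyzer run and img is the analyzed image; the return types follow the exporter hunks further down (export_html returns the rendered HTML string, export_csv the ordered element list), while the export_markdown hunks are not shown in this section:

    # Sketch only: `results` and `img` are assumed to come from DocumentAnalyzer.
    html_string = results.to_html("result.html", img=img)   # rendered HTML string
    csv_elements = results.to_csv("result.csv", img=img)    # ordered paragraph/table elements
    md_output = results.to_markdown("result.md", img=img)   # whatever export_markdown returns
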
@@ -98,47 +101,56 @@ def extract_words_within_element(pred_words, element):
     word_sum_width = 0
     word_sum_height = 0
     check_list = [False] * len(pred_words)
+
     for i, word in enumerate(pred_words):
         word_box = quad_to_xyxy(word.points)
         if is_contained(element.box, word_box, threshold=0.5):
-            contained_words.append(word)
             word_sum_width += word_box[2] - word_box[0]
             word_sum_height += word_box[3] - word_box[1]
             check_list[i] = True
 
+            word_element = ParagraphSchema(
+                box=word_box,
+                contents=word.content,
+                direction=word.direction,
+                order=0,
+                role=None,
+            )
+            contained_words.append(word_element)
+
     if len(contained_words) == 0:
         return None, None, check_list
 
-    mean_width = word_sum_width / len(contained_words)
-    mean_height = word_sum_height / len(contained_words)
-
     word_direction = [word.direction for word in contained_words]
     cnt_horizontal = word_direction.count("horizontal")
     cnt_vertical = word_direction.count("vertical")
 
     element_direction = "horizontal" if cnt_horizontal > cnt_vertical else "vertical"
-    if element_direction == "horizontal":
-        contained_words = sorted(
-            contained_words,
-            key=lambda x: (
-                x.points[0][1] // int(mean_height),
-                x.points[0][0],
-            ),
-        )
-    else:
-        contained_words = sorted(
-            contained_words,
-            key=lambda x: (
-                x.points[1][0] // int(mean_width),
-                x.points[1][1],
-            ),
-            reverse=True,
-        )
 
-    contained_words = "\n".join([content.content for content in contained_words])
+    prediction_reading_order(contained_words, element_direction)
+    contained_words = sorted(contained_words, key=lambda x: x.order)
+
+    contained_words = "\n".join([content.contents for content in contained_words])
+
     return (contained_words, element_direction, check_list)
 
 
+def is_vertical(quad, thresh_aspect=2):
+    quad = np.array(quad)
+    width = np.linalg.norm(quad[0] - quad[1])
+    height = np.linalg.norm(quad[1] - quad[2])
+
+    return height > width * thresh_aspect
+
+
+def is_noise(quad, thresh=15):
+    quad = np.array(quad)
+    width = np.linalg.norm(quad[0] - quad[1])
+    height = np.linalg.norm(quad[1] - quad[2])
+
+    return width < thresh or height < thresh
+
+
 def recursive_update(original, new_data):
     for key, value in new_data.items():
         # If `value` is a dict, update it recursively
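
The hunk above replaces the manual row/column sorting in extract_words_within_element with the shared prediction_reading_order routine and adds two geometry helpers, is_vertical and is_noise, which later feed the table-cell splitting logic. Both classify a quadrilateral from its first two edge lengths: p0→p1 is treated as the width, p1→p2 as the height. A small worked check with made-up coordinates:

    import numpy as np

    # Quads are given clockwise from the top-left corner; units are pixels.
    tall = np.array([[0, 0], [10, 0], [10, 40], [0, 40]])
    width = np.linalg.norm(tall[0] - tall[1])    # 10.0
    height = np.linalg.norm(tall[1] - tall[2])   # 40.0
    print(height > width * 2)                    # True -> is_vertical(tall) is True

    tiny = np.array([[0, 0], [8, 0], [8, 8], [0, 8]])
    print(np.linalg.norm(tiny[0] - tiny[1]) < 15)  # True -> is_noise(tiny) is True
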
@@ -154,8 +166,169 @@ def recursive_update(original, new_data):
     return original
 
 
+def _extract_words_within_table(words, table, check_list):
+    horizontal_words = []
+    vertical_words = []
+
+    for i, (points, score) in enumerate(zip(words.points, words.scores)):
+        word_box = quad_to_xyxy(points)
+        if is_contained(table.box, word_box, threshold=0.5):
+            if is_vertical(points):
+                vertical_words.append({"points": points, "score": score})
+            else:
+                horizontal_words.append({"points": points, "score": score})
+
+            check_list[i] = True
+
+    return (horizontal_words, vertical_words, check_list)
+
+
+def _calc_overlap_words_on_lines(lines, words):
+    overlap_ratios = [[0 for _ in lines] for _ in words]
+
+    for i, word in enumerate(words):
+        word_box = quad_to_xyxy(word["points"])
+        for j, row in enumerate(lines):
+            overlap_ratio, _ = calc_overlap_ratio(
+                row.box,
+                word_box,
+            )
+            overlap_ratios[i][j] = overlap_ratio
+
+    return overlap_ratios
+
+
+def _correct_vertical_word_boxes(overlap_ratios_vertical, table, table_words_vertical):
+    allocated_cols = [cols.index(max(cols)) for cols in overlap_ratios_vertical]
+
+    new_points = []
+    new_scores = []
+    for i, col_index in enumerate(allocated_cols):
+        col_cells = []
+        for cell in table.cells:
+            if cell.col <= (col_index + 1) < (cell.col + cell.col_span):
+                col_cells.append(cell)
+
+        word_point = table_words_vertical[i]["points"]
+        word_score = table_words_vertical[i]["score"]
+
+        for cell in col_cells:
+            word_box = quad_to_xyxy(word_point)
+
+            _, intersection = calc_overlap_ratio(
+                cell.box,
+                word_box,
+            )
+
+            if intersection is not None:
+                _, y1, _, y2 = intersection
+
+                new_point = [
+                    [word_point[0][0], max(word_point[0][1], y1)],
+                    [word_point[1][0], max(word_point[1][1], y1)],
+                    [word_point[2][0], min(word_point[2][1], y2)],
+                    [word_point[3][0], min(word_point[3][1], y2)],
+                ]
+
+                if not is_noise(new_point):
+                    new_points.append(new_point)
+                    new_scores.append(word_score)
+
+    return new_points, new_scores
+
+
+def _correct_horizontal_word_boxes(
+    overlap_ratios_horizontal, table, table_words_horizontal
+):
+    allocated_rows = [rows.index(max(rows)) for rows in overlap_ratios_horizontal]
+
+    new_points = []
+    new_scores = []
+    for i, row_index in enumerate(allocated_rows):
+        row_cells = []
+        for cell in table.cells:
+            if cell.row <= (row_index + 1) < (cell.row + cell.row_span):
+                row_cells.append(cell)
+
+        word_point = table_words_horizontal[i]["points"]
+        word_score = table_words_horizontal[i]["score"]
+
+        for cell in row_cells:
+            word_box = quad_to_xyxy(word_point)
+
+            _, intersection = calc_overlap_ratio(
+                cell.box,
+                word_box,
+            )
+
+            if intersection is not None:
+                x1, _, x2, _ = intersection
+
+                new_point = [
+                    [max(word_point[0][0], x1), word_point[0][1]],
+                    [min(word_point[1][0], x2), word_point[1][1]],
+                    [min(word_point[2][0], x2), word_point[2][1]],
+                    [max(word_point[3][0], x1), word_point[3][1]],
+                ]
+
+                if not is_noise(new_point):
+                    new_points.append(new_point)
+                    new_scores.append(word_score)
+
+    return new_points, new_scores
+
+
+def _split_text_across_cells(results_det, results_layout):
+    check_list = [False] * len(results_det.points)
+    new_points = []
+    new_scores = []
+    for table in results_layout.tables:
+        table_words_horizontal, table_words_vertical, check_list = (
+            _extract_words_within_table(results_det, table, check_list)
+        )
+
+        overlap_ratios_horizontal = _calc_overlap_words_on_lines(
+            table.rows,
+            table_words_horizontal,
+        )
+
+        overlap_ratios_vertical = _calc_overlap_words_on_lines(
+            table.cols,
+            table_words_vertical,
+        )
+
+        new_points_horizontal, new_scores_horizontal = _correct_horizontal_word_boxes(
+            overlap_ratios_horizontal, table, table_words_horizontal
+        )
+
+        new_points_vertical, new_scores_vertical = _correct_vertical_word_boxes(
+            overlap_ratios_vertical, table, table_words_vertical
+        )
+
+        new_points.extend(new_points_horizontal)
+        new_scores.extend(new_scores_horizontal)
+        new_points.extend(new_points_vertical)
+        new_scores.extend(new_scores_vertical)
+
+    for i, flag in enumerate(check_list):
+        if not flag:
+            new_points.append(results_det.points[i])
+            new_scores.append(results_det.scores[i])
+
+    results_det.points = new_points
+    results_det.scores = new_scores
+
+    return results_det
+
+
 class DocumentAnalyzer:
-    def __init__(self, configs=None, device="cuda", visualize=False):
+    def __init__(
+        self,
+        configs={},
+        device="cuda",
+        visualize=False,
+        ignore_meta=False,
+    ):
         default_configs = {
             "ocr": {
                 "text_detector": {
@@ -186,10 +359,20 @@ class DocumentAnalyzer:
                 "configs must be a dict. See the https://kotaro-kinoshita.github.io/yomitoku-dev/usage/"
             )
 
-        self.ocr = OCR(configs=default_configs["ocr"])
-        self.layout = LayoutAnalyzer(configs=default_configs["layout_analyzer"])
+        self.text_detector = TextDetector(
+            **default_configs["ocr"]["text_detector"],
+        )
+        self.text_recognizer = TextRecognizer(
+            **default_configs["ocr"]["text_recognizer"]
+        )
+
+        self.layout = LayoutAnalyzer(
+            configs=default_configs["layout_analyzer"],
+        )
         self.visualize = visualize
 
+        self.ignore_meta = ignore_meta
+
     def aggregate(self, ocr_res, layout_res):
         paragraphs = []
         check_list = [False] * len(ocr_res.words)
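
In the constructor hunk above, DocumentAnalyzer now instantiates TextDetector and TextRecognizer directly instead of wrapping them in an OCR object, and gains an ignore_meta flag that, per the filtering hunk below, excludes page_header and page_footer paragraphs from the aggregated result. A hedged construction example; the import path assumes the package root still re-exports DocumentAnalyzer:

    from yomitoku import DocumentAnalyzer  # assumed re-export from the package root

    analyzer = DocumentAnalyzer(
        configs={},          # per-model overrides, merged into default_configs (see recursive_update)
        device="cpu",        # "cuda" remains the default
        visualize=False,
        ignore_meta=True,    # drop page headers/footers from the aggregated output
    )
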
@@ -250,11 +433,15 @@
         page_direction = judge_page_direction(paragraphs)
 
         headers = [
-            paragraph for paragraph in paragraphs if paragraph.role == "page_header"
+            paragraph
+            for paragraph in paragraphs
+            if paragraph.role == "page_header" and not self.ignore_meta
        ]
 
         footers = [
-            paragraph for paragraph in paragraphs if paragraph.role == "page_footer"
+            paragraph
+            for paragraph in paragraphs
+            if paragraph.role == "page_footer" and not self.ignore_meta
         ]
 
         page_contents = [
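
The private helpers introduced earlier in this file come together in _split_text_across_cells, which the run() hunk below calls before recognition: words detected inside a table are split by orientation, assigned to the table row or column they overlap most, clipped to the matching cell boundaries, and dropped if the clipped quad degenerates into noise. A tiny self-contained sketch of the assignment step, with made-up overlap ratios:

    # Each inner list holds one word's overlap ratio with every table row,
    # mirroring the output of _calc_overlap_words_on_lines.
    overlap_ratios_horizontal = [
        [0.05, 0.90, 0.00],   # word 0 sits mostly on row 1
        [0.70, 0.10, 0.00],   # word 1 sits mostly on row 0
    ]
    allocated_rows = [rows.index(max(rows)) for rows in overlap_ratios_horizontal]
    print(allocated_rows)  # [1, 0]
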
@@ -292,24 +479,39 @@ class DocumentAnalyzer:
         with ThreadPoolExecutor(max_workers=2) as executor:
             loop = asyncio.get_running_loop()
             tasks = [
-                loop.run_in_executor(executor, self.ocr, img),
+                # loop.run_in_executor(executor, self.ocr, img),
+                loop.run_in_executor(executor, self.text_detector, img),
                 loop.run_in_executor(executor, self.layout, img),
             ]
 
             results = await asyncio.gather(*tasks)
 
-            results_ocr, ocr = results[0]
+            results_det, _ = results[0]
             results_layout, layout = results[1]
 
-            outputs = self.aggregate(results_ocr, results_layout)
+            results_det = _split_text_across_cells(results_det, results_layout)
+
+            vis_det = None
+            if self.visualize:
+                vis_det = det_visualizer(
+                    img,
+                    results_det.points,
+                )
+
+            results_rec, ocr = self.text_recognizer(img, results_det.points, vis_det)
+
+            outputs = {"words": ocr_aggregate(results_det, results_rec)}
+            results_ocr = OCRSchema(**outputs)
+            outputs = self.aggregate(results_ocr, results_layout)
+
         results = DocumentAnalyzerSchema(**outputs)
         return results, ocr, layout
 
     def __call__(self, img):
         self.img = img
-        resutls, ocr, layout = asyncio.run(self.run(img))
+        results, ocr, layout = asyncio.run(self.run(img))
 
         if self.visualize:
-            layout = reading_order_visualizer(layout, resutls)
+            layout = reading_order_visualizer(layout, results)
 
-        return resutls, ocr, layout
+        return results, ocr, layout
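
With the hunk above, run() executes text detection and layout analysis concurrently, reconciles detected boxes with table cells via _split_text_across_cells, and only then runs recognition; __call__ also fixes the resutls/results typo. A hedged end-to-end sketch; the file names are illustrative, and the second and third return values are assumed to be the OCR and layout visualization images, as in earlier releases:

    import cv2
    from yomitoku import DocumentAnalyzer  # assumed re-export from the package root

    analyzer = DocumentAnalyzer(device="cpu", visualize=True)
    img = cv2.imread("sample.jpg")         # illustrative input path

    # Detection and layout run in parallel; recognition follows the cell-aware box correction.
    results, ocr_vis, layout_vis = analyzer(img)

    results.to_html("result.html", img=img)
    cv2.imwrite("layout_reading_order.jpg", layout_vis)  # reading-order overlay when visualize=True
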
@@ -1,6 +1,19 @@
-from .export_csv import export_csv
-from .export_html import export_html
-from .export_json import export_json
-from .export_markdown import export_markdown
+from .export_csv import export_csv, save_csv, convert_csv
+from .export_html import export_html, save_html, convert_html
+from .export_json import export_json, save_json, convert_json
+from .export_markdown import export_markdown, save_markdown, convert_markdown
 
-__all__ = ["export_html", "export_markdown", "export_csv", "export_json"]
+__all__ = [
+    "export_html",
+    "export_markdown",
+    "export_csv",
+    "export_json",
+    "save_html",
+    "save_markdown",
+    "save_csv",
+    "save_json",
+    "convert_html",
+    "convert_markdown",
+    "convert_csv",
+    "convert_json",
+]
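
The export package (yomitoku/export/__init__.py) now exposes the conversion step (convert_*) and the file-writing step (save_*) for each format alongside the original export_* entry points, so callers can obtain the converted data without writing a file, or write it with an explicit encoding. The public surface can be inspected directly:

    import yomitoku.export as export

    print(export.__all__)
    # ['export_html', 'export_markdown', 'export_csv', 'export_json',
    #  'save_html', 'save_markdown', 'save_csv', 'save_json',
    #  'convert_html', 'convert_markdown', 'convert_csv', 'convert_json']
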
@@ -1,4 +1,7 @@
 import csv
+import os
+
+import cv2
 
 
 def table_to_csv(table, ignore_line_break):
@@ -33,7 +36,35 @@ def paragraph_to_csv(paragraph, ignore_line_break):
     return contents
 
 
-def export_csv(inputs, out_path: str, ignore_line_break: bool = False):
+def save_figure(
+    figures,
+    img,
+    out_path,
+    figure_dir="figures",
+):
+    assert img is not None, "img is required for saving figures"
+
+    for i, figure in enumerate(figures):
+        x1, y1, x2, y2 = map(int, figure.box)
+        figure_img = img[y1:y2, x1:x2, :]
+        save_dir = os.path.dirname(out_path)
+        save_dir = os.path.join(save_dir, figure_dir)
+        os.makedirs(save_dir, exist_ok=True)
+
+        filename = os.path.splitext(os.path.basename(out_path))[0]
+        figure_name = f"{filename}_figure_{i}.png"
+        figure_path = os.path.join(save_dir, figure_name)
+        cv2.imwrite(figure_path, figure_img)
+
+
+def convert_csv(
+    inputs,
+    out_path,
+    ignore_line_break,
+    img=None,
+    export_figure: bool = True,
+    figure_dir="figures",
+):
     elements = []
     for table in inputs.tables:
         table_csv = table_to_csv(table, ignore_line_break)
@@ -60,7 +91,45 @@ def export_csv(inputs, out_path: str, ignore_line_break: bool = False):
 
     elements = sorted(elements, key=lambda x: x["order"])
 
-    with open(out_path, "w", newline="", encoding="utf-8") as f:
+    if export_figure:
+        save_figure(
+            inputs.figures,
+            img,
+            out_path,
+            figure_dir=figure_dir,
+        )
+
+    return elements
+
+
+def export_csv(
+    inputs,
+    out_path: str,
+    ignore_line_break: bool = False,
+    encoding: str = "utf-8",
+    img=None,
+    export_figure: bool = True,
+    figure_dir="figures",
+):
+    elements = convert_csv(
+        inputs,
+        out_path,
+        ignore_line_break,
+        img,
+        export_figure,
+        figure_dir,
+    )
+
+    save_csv(elements, out_path, encoding)
+    return elements
+
+
+def save_csv(
+    elements,
+    out_path,
+    encoding,
+):
+    with open(out_path, "w", newline="", encoding=encoding, errors="ignore") as f:
         writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
         for element in elements:
             if element["type"] == "table":
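
export_csv is now a thin wrapper: convert_csv builds the ordered element list (and crops figures when asked), and save_csv writes the file with a caller-chosen encoding and errors="ignore". A hedged usage sketch, reusing the results and img assumed above; "cp932" is only an example of a non-default codec:

    from yomitoku.export import export_csv

    elements = export_csv(
        results,
        "result.csv",
        ignore_line_break=True,
        encoding="cp932",     # any codec accepted by open(); "utf-8" is the default
        img=img,
        export_figure=True,   # crops figure regions into a "figures" dir next to the CSV
    )
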
@@ -1,9 +1,8 @@
-import re
 import os
-import cv2
-
+import re
 from html import escape
 
+import cv2
 from lxml import etree, html
 
 
@@ -110,6 +109,8 @@ def figure_to_html(
     figure_dir="figures",
     width=200,
 ):
+    assert img is not None, "img is required for saving figures"
+
     elements = []
     for i, figure in enumerate(figures):
         x1, y1, x2, y2 = map(int, figure.box)
@@ -145,12 +146,12 @@ def figure_to_html(
     return elements
 
 
-def export_html(
+def convert_html(
     inputs,
-    out_path: str,
-    ignore_line_break: bool = False,
-    export_figure: bool = True,
-    export_figure_letter: bool = False,
+    out_path,
+    ignore_line_break,
+    export_figure,
+    export_figure_letter,
     img=None,
     figure_width=200,
     figure_dir="figures",
@@ -179,10 +180,43 @@
     elements = sorted(elements, key=lambda x: x["order"])
 
     html_string = "".join([element["html"] for element in elements])
-    html_string = add_html_tag(html_string)
-
     parsed_html = html.fromstring(html_string)
     formatted_html = etree.tostring(parsed_html, pretty_print=True, encoding="unicode")
 
-    with open(out_path, "w", encoding="utf-8") as f:
-        f.write(formatted_html)
+    return formatted_html, elements
+
+
+def export_html(
+    inputs,
+    out_path: str,
+    ignore_line_break: bool = False,
+    export_figure: bool = True,
+    export_figure_letter: bool = False,
+    img=None,
+    figure_width=200,
+    figure_dir="figures",
+    encoding: str = "utf-8",
+):
+    formatted_html, elements = convert_html(
+        inputs,
+        out_path,
+        ignore_line_break,
+        export_figure,
+        export_figure_letter,
+        img,
+        figure_width,
+        figure_dir,
+    )
+
+    save_html(formatted_html, out_path, encoding)
+
+    return formatted_html
+
+
+def save_html(
+    html,
+    out_path,
+    encoding,
+):
+    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
+        f.write(html)
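
export_html follows the same split: convert_html renders the HTML string (no longer passing it through add_html_tag) and export_html delegates the write to save_html with the requested encoding. A hedged sketch under the same results/img assumptions:

    from yomitoku.export import convert_html, save_html

    formatted_html, elements = convert_html(
        results,
        "result.html",
        ignore_line_break=False,
        export_figure=True,
        export_figure_letter=False,
        img=img,
        figure_width=200,
        figure_dir="figures",
    )
    save_html(formatted_html, "result.html", encoding="utf-8")
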
@@ -1,4 +1,7 @@
 import json
+import os
+
+import cv2
 
 
 def paragraph_to_json(paragraph, ignore_line_break):
@@ -12,7 +15,28 @@ def table_to_json(table, ignore_line_break):
         cell.contents = cell.contents.replace("\n", "")
 
 
-def export_json(inputs, out_path, ignore_line_break=False):
+def save_figure(
+    figures,
+    img,
+    out_path,
+    figure_dir="figures",
+):
+    assert img is not None, "img is required for saving figures"
+
+    for i, figure in enumerate(figures):
+        x1, y1, x2, y2 = map(int, figure.box)
+        figure_img = img[y1:y2, x1:x2, :]
+        save_dir = os.path.dirname(out_path)
+        save_dir = os.path.join(save_dir, figure_dir)
+        os.makedirs(save_dir, exist_ok=True)
+
+        filename = os.path.splitext(os.path.basename(out_path))[0]
+        figure_name = f"{filename}_figure_{i}.png"
+        figure_path = os.path.join(save_dir, figure_name)
+        cv2.imwrite(figure_path, figure_img)
+
+
+def convert_json(inputs, out_path, ignore_line_break, img, export_figure, figure_dir):
     from yomitoku.document_analyzer import DocumentAnalyzerSchema
 
     if isinstance(inputs, DocumentAnalyzerSchema):
@@ -23,9 +47,48 @@ def export_json(inputs, out_path, ignore_line_break=False):
     for paragraph in inputs.paragraphs:
         paragraph_to_json(paragraph, ignore_line_break)
 
-    with open(out_path, "w", encoding="utf-8") as f:
+    if isinstance(inputs, DocumentAnalyzerSchema) and export_figure:
+        save_figure(
+            inputs.figures,
+            img,
+            out_path,
+            figure_dir=figure_dir,
+        )
+
+    return inputs
+
+
+def export_json(
+    inputs,
+    out_path,
+    ignore_line_break=False,
+    encoding: str = "utf-8",
+    img=None,
+    export_figure=False,
+    figure_dir="figures",
+):
+    inputs = convert_json(
+        inputs,
+        out_path,
+        ignore_line_break,
+        img,
+        export_figure,
+        figure_dir,
+    )
+
+    save_json(
+        inputs.model_dump(),
+        out_path,
+        encoding,
+    )
+
+    return inputs
+
+
+def save_json(data, out_path, encoding):
+    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
         json.dump(
-            inputs.model_dump(),
+            data,
             f,
             ensure_ascii=False,
             indent=4,
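
export_json gets the same treatment: convert_json normalizes the schema (and optionally saves figure crops), while save_json serializes a plain dict with the requested encoding. A hedged sketch under the same results assumption:

    from yomitoku.export import convert_json, save_json

    data = convert_json(
        results,
        "result.json",
        ignore_line_break=False,
        img=None,
        export_figure=False,   # no figure crops, so img can stay None
        figure_dir="figures",
    )
    save_json(data.model_dump(), "result.json", encoding="utf-8")
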