yomitoku 0.5.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yomitoku/cli/main.py CHANGED
@@ -13,6 +13,18 @@ from ..utils.logger import set_logger
 logger = set_logger(__name__, "INFO")
 
 
+def validate_encoding(encoding):
+    if encoding not in [
+        "utf-8",
+        "utf-8-sig",
+        "shift-jis",
+        "euc-jp",
+        "cp932",
+    ]:
+        raise ValueError(f"Invalid encoding: {encoding}")
+    return True
+
+
 def process_single_file(args, analyzer, path, format):
     if path.suffix[1:].lower() in ["pdf"]:
         imgs = load_pdf(path)
@@ -21,7 +33,6 @@ def process_single_file(args, analyzer, path, format):
 
     for page, img in enumerate(imgs):
         results, ocr, layout = analyzer(img)
-
         dirname = path.parent.name
         filename = path.stem
 
@@ -47,11 +58,19 @@ def process_single_file(args, analyzer, path, format):
             results.to_json(
                 out_path,
                 ignore_line_break=args.ignore_line_break,
+                encoding=args.encoding,
+                img=img,
+                export_figure=args.figure,
+                figure_dir=args.figure_dir,
             )
         elif format == "csv":
             results.to_csv(
                 out_path,
                 ignore_line_break=args.ignore_line_break,
+                encoding=args.encoding,
+                img=img,
+                export_figure=args.figure,
+                figure_dir=args.figure_dir,
             )
         elif format == "html":
             results.to_html(
@@ -62,6 +81,7 @@ def process_single_file(args, analyzer, path, format):
                 export_figure_letter=args.figure_letter,
                 figure_width=args.figure_width,
                 figure_dir=args.figure_dir,
+                encoding=args.encoding,
             )
         elif format == "md":
             results.to_markdown(
@@ -72,6 +92,7 @@ def process_single_file(args, analyzer, path, format):
                 export_figure_letter=args.figure_letter,
                 figure_width=args.figure_width,
                 figure_dir=args.figure_dir,
+                encoding=args.encoding,
             )
 
         logger.info(f"Output file: {out_path}")
@@ -104,6 +125,12 @@ def main():
         default="results",
         help="output directory",
     )
+    parser.add_argument(
+        "-l",
+        "--lite",
+        action="store_true",
+        help="if set, use lite model",
+    )
     parser.add_argument(
         "-d",
         "--device",
@@ -162,6 +189,12 @@ def main():
         default="figures",
         help="directory to save figure images",
     )
+    parser.add_argument(
+        "--encoding",
+        type=str,
+        default="utf-8",
+        help="Specifies the character encoding for the output file to be exported. If unsupported characters are included, they will be ignored.",
+    )
 
     args = parser.parse_args()
 
@@ -175,6 +208,8 @@ def main():
            f"Invalid output format: {args.format}. Supported formats are {SUPPORT_OUTPUT_FORMAT}"
        )
 
+    validate_encoding(args.encoding)
+
     if format == "markdown":
         format = "md"
 
@@ -197,6 +232,17 @@ def main():
         },
     }
 
+    if args.lite:
+        configs["ocr"]["text_recognizer"]["model_name"] = "parseq-small"
+
+    if args.device == "cpu":
+        configs["ocr"]["text_detector"]["infer_onnx"] = True
+
+    # Note: PyTorch inference is faster than ONNX for everything except the Text Detector, so ONNX is not enabled for the others
+    # configs["ocr"]["text_recognizer"]["infer_onnx"] = True
+    # configs["layout_analyzer"]["table_structure_recognizer"]["infer_onnx"] = True
+    # configs["layout_analyzer"]["layout_parser"]["infer_onnx"] = True
+
     analyzer = DocumentAnalyzer(
         configs=configs,
         visualize=args.vis,
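
Note: the new --encoding value is checked up front by validate_encoding(), --lite switches the text recognizer to the parseq-small model, and CPU runs enable ONNX inference for the text detector only. A minimal sketch of the validation behavior, using the exact whitelist from the function above:

# Sketch: behavior of the validate_encoding() helper added in this diff.
validate_encoding("cp932")      # returns True (one of the five supported encodings)
validate_encoding("utf-8-sig")  # returns True
validate_encoding("latin-1")    # raises ValueError: Invalid encoding: latin-1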
yomitoku/configs/__init__.py CHANGED
@@ -4,10 +4,12 @@ from .cfg_table_structure_recognizer_rtdtrv2 import (
 )
 from .cfg_text_detector_dbnet import TextDetectorDBNetConfig
 from .cfg_text_recognizer_parseq import TextRecognizerPARSeqConfig
+from .cfg_text_recognizer_parseq_small import TextRecognizerPARSeqSmallConfig
 
 __all__ = [
     "TextDetectorDBNetConfig",
     "TextRecognizerPARSeqConfig",
     "LayoutParserRTDETRv2Config",
     "TableStructureRecognizerRTDETRv2Config",
+    "TextRecognizerPARSeqSmallConfig",
 ]
yomitoku/configs/cfg_text_recognizer_parseq_small.py ADDED
@@ -0,0 +1,51 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from ..constants import ROOT_DIR
+
+
+@dataclass
+class Data:
+    num_workers: int = 4
+    batch_size: int = 128
+    img_size: List[int] = field(default_factory=lambda: [32, 800])
+
+
+@dataclass
+class Encoder:
+    patch_size: List[int] = field(default_factory=lambda: [16, 16])
+    num_heads: int = 8
+    embed_dim: int = 384
+    mlp_ratio: int = 4
+    depth: int = 9
+
+
+@dataclass
+class Decoder:
+    embed_dim: int = 384
+    num_heads: int = 8
+    mlp_ratio: int = 4
+    depth: int = 1
+
+
+@dataclass
+class Visualize:
+    font: str = str(ROOT_DIR + "/resource/MPLUS1p-Medium.ttf")
+    color: List[int] = field(default_factory=lambda: [0, 0, 255])  # RGB
+    font_size: int = 18
+
+
+@dataclass
+class TextRecognizerPARSeqSmallConfig:
+    hf_hub_repo: str = "KotaroKinoshita/yomitoku-text-recognizer-parseq-small-open-beta"
+    charset: str = str(ROOT_DIR + "/resource/charset.txt")
+    num_tokens: int = 7312
+    max_label_length: int = 100
+    decode_ar: int = 1
+    refine_iters: int = 1
+
+    data: Data = field(default_factory=Data)
+    encoder: Encoder = field(default_factory=Encoder)
+    decoder: Decoder = field(default_factory=Decoder)
+
+    visualize: Visualize = field(default_factory=Visualize)
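
For reference, the new config is a plain dataclass and can be instantiated and inspected directly; a minimal sketch using the defaults defined above:

from yomitoku.configs import TextRecognizerPARSeqSmallConfig

cfg = TextRecognizerPARSeqSmallConfig()
print(cfg.hf_hub_repo)    # KotaroKinoshita/yomitoku-text-recognizer-parseq-small-open-beta
print(cfg.encoder.depth)  # 9 (a 9-block encoder with embed_dim 384)
print(cfg.data.img_size)  # [32, 800]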
yomitoku/document_analyzer.py CHANGED
@@ -2,17 +2,26 @@ import asyncio
 from concurrent.futures import ThreadPoolExecutor
 from typing import List, Union
 
+import numpy as np
+
 from pydantic import conlist
 
 from .base import BaseSchema
 from .export import export_csv, export_html, export_markdown
 from .layout_analyzer import LayoutAnalyzer
-from .ocr import OCR, WordPrediction
-from .table_structure_recognizer import TableStructureRecognizerSchema
-from .utils.misc import is_contained, quad_to_xyxy
+from .ocr import OCRSchema, WordPrediction, ocr_aggregate
 from .reading_order import prediction_reading_order
-
+from .table_structure_recognizer import TableStructureRecognizerSchema
+from .utils.misc import (
+    is_contained,
+    quad_to_xyxy,
+    calc_overlap_ratio,
+)
 from .utils.visualizer import reading_order_visualizer
+from yomitoku.text_detector import TextDetector
+from yomitoku.text_recognizer import TextRecognizer
+
+from .utils.visualizer import det_visualizer
 
 
 class ParagraphSchema(BaseSchema):
@@ -98,41 +107,57 @@ def extract_words_within_element(pred_words, element):
     word_sum_width = 0
     word_sum_height = 0
     check_list = [False] * len(pred_words)
+
     for i, word in enumerate(pred_words):
         word_box = quad_to_xyxy(word.points)
         if is_contained(element.box, word_box, threshold=0.5):
-            contained_words.append(word)
             word_sum_width += word_box[2] - word_box[0]
             word_sum_height += word_box[3] - word_box[1]
             check_list[i] = True
 
+            word_element = ParagraphSchema(
+                box=word_box,
+                contents=word.content,
+                direction=word.direction,
+                order=0,
+                role=None,
+            )
+            contained_words.append(word_element)
+
     if len(contained_words) == 0:
         return None, None, check_list
 
-    # mean_width = word_sum_width / len(contained_words)
-    # mean_height = word_sum_height / len(contained_words)
-
+    element_direction = "horizontal"
     word_direction = [word.direction for word in contained_words]
     cnt_horizontal = word_direction.count("horizontal")
     cnt_vertical = word_direction.count("vertical")
 
     element_direction = "horizontal" if cnt_horizontal > cnt_vertical else "vertical"
-    if element_direction == "horizontal":
-        contained_words = sorted(
-            contained_words,
-            key=lambda x: (sum([p[1] for p in x.points]) / 4),
-        )
-    else:
-        contained_words = sorted(
-            contained_words,
-            key=lambda x: (sum([p[0] for p in x.points]) / 4),
-            reverse=True,
-        )
 
-    contained_words = "\n".join([content.content for content in contained_words])
+    prediction_reading_order(contained_words, element_direction)
+    contained_words = sorted(contained_words, key=lambda x: x.order)
+
+    contained_words = "\n".join([content.contents for content in contained_words])
+
     return (contained_words, element_direction, check_list)
 
 
+def is_vertical(quad, thresh_aspect=2):
+    quad = np.array(quad)
+    width = np.linalg.norm(quad[0] - quad[1])
+    height = np.linalg.norm(quad[1] - quad[2])
+
+    return height > width * thresh_aspect
+
+
+def is_noise(quad, thresh=15):
+    quad = np.array(quad)
+    width = np.linalg.norm(quad[0] - quad[1])
+    height = np.linalg.norm(quad[1] - quad[2])
+
+    return width < thresh or height < thresh
+
+
 def recursive_update(original, new_data):
     for key, value in new_data.items():
         # If `value` is a dict, update it recursively
@@ -148,8 +173,163 @@ def recursive_update(original, new_data):
     return original
 
 
+def _extract_words_within_table(words, table, check_list):
+    horizontal_words = []
+    vertical_words = []
+
+    for i, (points, score) in enumerate(zip(words.points, words.scores)):
+        word_box = quad_to_xyxy(points)
+        if is_contained(table.box, word_box, threshold=0.5):
+            if is_vertical(points):
+                vertical_words.append({"points": points, "score": score})
+            else:
+                horizontal_words.append({"points": points, "score": score})
+
+            check_list[i] = True
+
+    return (horizontal_words, vertical_words, check_list)
+
+
+def _calc_overlap_words_on_lines(lines, words):
+    overlap_ratios = [[0 for _ in lines] for _ in words]
+
+    for i, word in enumerate(words):
+        word_box = quad_to_xyxy(word["points"])
+        for j, row in enumerate(lines):
+            overlap_ratio, _ = calc_overlap_ratio(
+                row.box,
+                word_box,
+            )
+            overlap_ratios[i][j] = overlap_ratio
+
+    return overlap_ratios
+
+
+def _correct_vertical_word_boxes(overlap_ratios_vertical, table, table_words_vertical):
+    allocated_cols = [cols.index(max(cols)) for cols in overlap_ratios_vertical]
+
+    new_points = []
+    new_scores = []
+    for i, col_index in enumerate(allocated_cols):
+        col_cells = []
+        for cell in table.cells:
+            if cell.col <= (col_index + 1) < (cell.col + cell.col_span):
+                col_cells.append(cell)
+
+        word_point = table_words_vertical[i]["points"]
+        word_score = table_words_vertical[i]["score"]
+
+        for cell in col_cells:
+            word_box = quad_to_xyxy(word_point)
+
+            _, intersection = calc_overlap_ratio(
+                cell.box,
+                word_box,
+            )
+
+            if intersection is not None:
+                _, y1, _, y2 = intersection
+
+                new_point = [
+                    [word_point[0][0], max(word_point[0][1], y1)],
+                    [word_point[1][0], max(word_point[1][1], y1)],
+                    [word_point[2][0], min(word_point[2][1], y2)],
+                    [word_point[3][0], min(word_point[3][1], y2)],
+                ]
+
+                if not is_noise(new_point):
+                    new_points.append(new_point)
+                    new_scores.append(word_score)
+
+    return new_points, new_scores
+
+
+def _correct_horizontal_word_boxes(
+    overlap_ratios_horizontal, table, table_words_horizontal
+):
+    allocated_rows = [rows.index(max(rows)) for rows in overlap_ratios_horizontal]
+
+    new_points = []
+    new_scores = []
+    for i, row_index in enumerate(allocated_rows):
+        row_cells = []
+        for cell in table.cells:
+            if cell.row <= (row_index + 1) < (cell.row + cell.row_span):
+                row_cells.append(cell)
+
+        word_point = table_words_horizontal[i]["points"]
+        word_score = table_words_horizontal[i]["score"]
+
+        for cell in row_cells:
+            word_box = quad_to_xyxy(word_point)
+
+            _, intersection = calc_overlap_ratio(
+                cell.box,
+                word_box,
+            )
+
+            if intersection is not None:
+                x1, _, x2, _ = intersection
+
+                new_point = [
+                    [max(word_point[0][0], x1), word_point[0][1]],
+                    [min(word_point[1][0], x2), word_point[1][1]],
+                    [min(word_point[2][0], x2), word_point[2][1]],
+                    [max(word_point[3][0], x1), word_point[3][1]],
+                ]
+
+                if not is_noise(new_point):
+                    new_points.append(new_point)
+                    new_scores.append(word_score)
+
+    return new_points, new_scores
+
+
+def _split_text_across_cells(results_det, results_layout):
+    check_list = [False] * len(results_det.points)
+    new_points = []
+    new_scores = []
+    for table in results_layout.tables:
+        table_words_horizontal, table_words_vertical, check_list = (
+            _extract_words_within_table(results_det, table, check_list)
+        )
+
+        overlap_ratios_horizontal = _calc_overlap_words_on_lines(
+            table.rows,
+            table_words_horizontal,
+        )
+
+        overlap_ratios_vertical = _calc_overlap_words_on_lines(
+            table.cols,
+            table_words_vertical,
+        )
+
+        new_points_horizontal, new_scores_horizontal = _correct_horizontal_word_boxes(
+            overlap_ratios_horizontal, table, table_words_horizontal
+        )
+
+        new_points_vertical, new_scores_vertical = _correct_vertical_word_boxes(
+            overlap_ratios_vertical, table, table_words_vertical
+        )
+
+        new_points.extend(new_points_horizontal)
+        new_scores.extend(new_scores_horizontal)
+        new_points.extend(new_points_vertical)
+        new_scores.extend(new_scores_vertical)
+
+    for i, flag in enumerate(check_list):
+        if not flag:
+            new_points.append(results_det.points[i])
+            new_scores.append(results_det.scores[i])
+
+    results_det.points = new_points
+    results_det.scores = new_scores
+
+    return results_det
+
+
 class DocumentAnalyzer:
-    def __init__(self, configs=None, device="cuda", visualize=False):
+    def __init__(self, configs={}, device="cuda", visualize=False):
         default_configs = {
             "ocr": {
                 "text_detector": {
@@ -180,8 +360,16 @@ class DocumentAnalyzer:
                 "configs must be a dict. See the https://kotaro-kinoshita.github.io/yomitoku-dev/usage/"
             )
 
-        self.ocr = OCR(configs=default_configs["ocr"])
-        self.layout = LayoutAnalyzer(configs=default_configs["layout_analyzer"])
+        self.text_detector = TextDetector(
+            **default_configs["ocr"]["text_detector"],
+        )
+        self.text_recognizer = TextRecognizer(
+            **default_configs["ocr"]["text_recognizer"]
+        )
+
+        self.layout = LayoutAnalyzer(
+            configs=default_configs["layout_analyzer"],
+        )
         self.visualize = visualize
 
     def aggregate(self, ocr_res, layout_res):
@@ -286,16 +474,31 @@
         with ThreadPoolExecutor(max_workers=2) as executor:
             loop = asyncio.get_running_loop()
             tasks = [
-                loop.run_in_executor(executor, self.ocr, img),
+                # loop.run_in_executor(executor, self.ocr, img),
+                loop.run_in_executor(executor, self.text_detector, img),
                 loop.run_in_executor(executor, self.layout, img),
             ]
 
             results = await asyncio.gather(*tasks)
 
-            results_ocr, ocr = results[0]
+            results_det, _ = results[0]
             results_layout, layout = results[1]
 
-            outputs = self.aggregate(results_ocr, results_layout)
+            results_det = _split_text_across_cells(results_det, results_layout)
+
+            vis_det = None
+            if self.visualize:
+                vis_det = det_visualizer(
+                    img,
+                    results_det.points,
+                )
+
+            results_rec, ocr = self.text_recognizer(img, results_det.points, vis_det)
+
+            outputs = {"words": ocr_aggregate(results_det, results_rec)}
+            results_ocr = OCRSchema(**outputs)
+            outputs = self.aggregate(results_ocr, results_layout)
+
         results = DocumentAnalyzerSchema(**outputs)
         return results, ocr, layout
 
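The analyzer now runs text detection and layout analysis in parallel, snaps the detected text boxes to table cell boundaries via _split_text_across_cells(), and only then runs recognition. The public call signature is unchanged; a minimal usage sketch (the top-level DocumentAnalyzer import and the sample file name are assumptions, not taken from this diff):

import cv2
from yomitoku import DocumentAnalyzer  # top-level re-export assumed

analyzer = DocumentAnalyzer(visualize=True, device="cpu")
img = cv2.imread("sample.jpg")  # hypothetical input image
results, ocr_vis, layout_vis = analyzer(img)  # same (results, ocr, layout) tuple as before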
yomitoku/export/export_csv.py CHANGED
@@ -1,4 +1,6 @@
 import csv
+import cv2
+import os
 
 
 def table_to_csv(table, ignore_line_break):
@@ -33,7 +35,34 @@ def paragraph_to_csv(paragraph, ignore_line_break):
     return contents
 
 
-def export_csv(inputs, out_path: str, ignore_line_break: bool = False):
+def save_figure(
+    figures,
+    img,
+    out_path,
+    figure_dir="figures",
+):
+    for i, figure in enumerate(figures):
+        x1, y1, x2, y2 = map(int, figure.box)
+        figure_img = img[y1:y2, x1:x2, :]
+        save_dir = os.path.dirname(out_path)
+        save_dir = os.path.join(save_dir, figure_dir)
+        os.makedirs(save_dir, exist_ok=True)
+
+        filename = os.path.splitext(os.path.basename(out_path))[0]
+        figure_name = f"{filename}_figure_{i}.png"
+        figure_path = os.path.join(save_dir, figure_name)
+        cv2.imwrite(figure_path, figure_img)
+
+
+def export_csv(
+    inputs,
+    out_path: str,
+    ignore_line_break: bool = False,
+    encoding: str = "utf-8",
+    img=None,
+    export_figure: bool = True,
+    figure_dir="figures",
+):
     elements = []
     for table in inputs.tables:
         table_csv = table_to_csv(table, ignore_line_break)
@@ -58,9 +87,17 @@ def export_csv(inputs, out_path: str, ignore_line_break: bool = False):
             }
         )
 
+    if export_figure:
+        save_figure(
+            inputs.figures,
+            img,
+            out_path,
+            figure_dir=figure_dir,
+        )
+
     elements = sorted(elements, key=lambda x: x["order"])
 
-    with open(out_path, "w", newline="", encoding="utf-8") as f:
+    with open(out_path, "w", newline="", encoding=encoding, errors="ignore") as f:
         writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
         for element in elements:
             if element["type"] == "table":
yomitoku/export/export_html.py CHANGED
@@ -154,6 +154,7 @@ def export_html(
     img=None,
     figure_width=200,
     figure_dir="figures",
+    encoding: str = "utf-8",
 ):
     html_string = ""
     elements = []
@@ -184,5 +185,5 @@ def export_html(
     parsed_html = html.fromstring(html_string)
     formatted_html = etree.tostring(parsed_html, pretty_print=True, encoding="unicode")
 
-    with open(out_path, "w", encoding="utf-8") as f:
+    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
         f.write(formatted_html)
yomitoku/export/export_json.py CHANGED
@@ -1,5 +1,8 @@
 import json
 
+import cv2
+import os
+
 
 def paragraph_to_json(paragraph, ignore_line_break):
     if ignore_line_break:
@@ -12,7 +15,34 @@ def table_to_json(table, ignore_line_break):
             cell.contents = cell.contents.replace("\n", "")
 
 
-def export_json(inputs, out_path, ignore_line_break=False):
+def save_figure(
+    figures,
+    img,
+    out_path,
+    figure_dir="figures",
+):
+    for i, figure in enumerate(figures):
+        x1, y1, x2, y2 = map(int, figure.box)
+        figure_img = img[y1:y2, x1:x2, :]
+        save_dir = os.path.dirname(out_path)
+        save_dir = os.path.join(save_dir, figure_dir)
+        os.makedirs(save_dir, exist_ok=True)
+
+        filename = os.path.splitext(os.path.basename(out_path))[0]
+        figure_name = f"{filename}_figure_{i}.png"
+        figure_path = os.path.join(save_dir, figure_name)
+        cv2.imwrite(figure_path, figure_img)
+
+
+def export_json(
+    inputs,
+    out_path,
+    ignore_line_break=False,
+    encoding: str = "utf-8",
+    img=None,
+    export_figure=False,
+    figure_dir="figures",
+):
     from yomitoku.document_analyzer import DocumentAnalyzerSchema
 
     if isinstance(inputs, DocumentAnalyzerSchema):
@@ -23,7 +53,15 @@ def export_json(inputs, out_path, ignore_line_break=False):
     for paragraph in inputs.paragraphs:
         paragraph_to_json(paragraph, ignore_line_break)
 
-    with open(out_path, "w", encoding="utf-8") as f:
+    if export_figure:
+        save_figure(
+            inputs.figures,
+            img,
+            out_path,
+            figure_dir=figure_dir,
+        )
+
+    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
         json.dump(
             inputs.model_dump(),
             f,
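
Together with the CLI changes above, a JSON export can now also crop figure images out of the page and save them next to the output file; a sketch mirroring the to_json() call in cli/main.py (the results object and img are assumed to come from DocumentAnalyzer):

results.to_json(
    "output/result.json",
    ignore_line_break=False,
    encoding="utf-8-sig",  # any of the five whitelisted encodings
    img=img,               # source image, required for figure cropping
    export_figure=True,
    figure_dir="figures",  # figures are saved as output/figures/result_figure_{i}.png
)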
yomitoku/export/export_markdown.py CHANGED
@@ -117,6 +117,7 @@ def export_markdown(
     export_figure=True,
     figure_width=200,
     figure_dir="figures",
+    encoding: str = "utf-8",
 ):
     elements = []
     for table in inputs.tables:
@@ -141,5 +142,5 @@ def export_markdown(
     elements = sorted(elements, key=lambda x: x["order"])
     markdown = "\n".join([element["md"] for element in elements])
 
-    with open(out_path, "w", encoding="utf-8") as f:
+    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
         f.write(markdown)
yomitoku/layout_analyzer.py CHANGED
@@ -15,7 +15,7 @@ class LayoutAnalyzerSchema(BaseSchema):
 
 
 class LayoutAnalyzer:
-    def __init__(self, configs=None, device="cuda", visualize=False):
+    def __init__(self, configs={}, device="cuda", visualize=False):
         layout_parser_kwargs = {
             "device": device,
             "visualize": visualize,
@@ -26,10 +26,6 @@ class LayoutAnalyzer:
         }
 
         if isinstance(configs, dict):
-            assert (
-                "layout_parser" in configs or "table_structure_recognizer" in configs
-            ), "Invalid config key. Please check the config keys."
-
             if "layout_parser" in configs:
                 layout_parser_kwargs.update(configs["layout_parser"])
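
With the assertion removed, LayoutAnalyzer (like DocumentAnalyzer) accepts an empty or partial configs dict: keys that are present override the defaults, absent ones keep them. A sketch, with illustrative override values:

from yomitoku.layout_analyzer import LayoutAnalyzer

# 0.5.3 raised an AssertionError unless "layout_parser" or
# "table_structure_recognizer" was present; 0.7.0 accepts a partial dict.
analyzer = LayoutAnalyzer(
    configs={"layout_parser": {"visualize": True}},  # illustrative override
    device="cpu",
)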