yomitoku 0.4.1 (yomitoku-0.4.1-py3-none-any.whl) → 0.7.4 (yomitoku-0.7.4-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. yomitoku/base.py +1 -1
  2. yomitoku/cli/main.py +219 -27
  3. yomitoku/configs/__init__.py +2 -0
  4. yomitoku/configs/cfg_text_detector_dbnet.py +1 -1
  5. yomitoku/configs/cfg_text_recognizer_parseq_small.py +51 -0
  6. yomitoku/data/functions.py +48 -23
  7. yomitoku/document_analyzer.py +243 -41
  8. yomitoku/export/__init__.py +18 -5
  9. yomitoku/export/export_csv.py +71 -2
  10. yomitoku/export/export_html.py +46 -12
  11. yomitoku/export/export_json.py +66 -3
  12. yomitoku/export/export_markdown.py +42 -6
  13. yomitoku/layout_analyzer.py +2 -9
  14. yomitoku/layout_parser.py +58 -4
  15. yomitoku/models/dbnet_plus.py +13 -39
  16. yomitoku/models/layers/activate.py +13 -0
  17. yomitoku/models/layers/rtdetr_backbone.py +18 -17
  18. yomitoku/models/layers/rtdetr_hybrid_encoder.py +19 -20
  19. yomitoku/models/layers/rtdetrv2_decoder.py +14 -1
  20. yomitoku/models/parseq.py +15 -22
  21. yomitoku/ocr.py +24 -27
  22. yomitoku/onnx/.gitkeep +0 -0
  23. yomitoku/postprocessor/dbnet_postporcessor.py +15 -14
  24. yomitoku/postprocessor/parseq_tokenizer.py +1 -3
  25. yomitoku/postprocessor/rtdetr_postprocessor.py +14 -1
  26. yomitoku/table_structure_recognizer.py +82 -9
  27. yomitoku/text_detector.py +57 -7
  28. yomitoku/text_recognizer.py +84 -16
  29. yomitoku/utils/misc.py +21 -14
  30. yomitoku/utils/visualizer.py +15 -8
  31. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/METADATA +34 -41
  32. yomitoku-0.7.4.dist-info/RECORD +54 -0
  33. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/WHEEL +1 -1
  34. yomitoku-0.4.1.dist-info/RECORD +0 -52
  35. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/entry_points.txt +0 -0
yomitoku/base.py CHANGED
@@ -54,7 +54,7 @@ class BaseSchema(BaseModel):
54
54
  validate_assignment = True
55
55
 
56
56
  def to_json(self, out_path: str, **kwargs):
57
- export_json(self, out_path, **kwargs)
57
+ return export_json(self, out_path, **kwargs)
58
58
 
59
59
 
60
60
  class BaseModule:
yomitoku/cli/main.py CHANGED
@@ -1,30 +1,94 @@
1
1
  import argparse
2
2
  import os
3
+ import time
3
4
  from pathlib import Path
4
5
 
5
6
  import cv2
6
- import time
7
+ import torch
7
8
 
8
9
  from ..constants import SUPPORT_OUTPUT_FORMAT
9
10
  from ..data.functions import load_image, load_pdf
10
11
  from ..document_analyzer import DocumentAnalyzer
11
12
  from ..utils.logger import set_logger
12
13
 
14
+ from ..export import save_csv, save_html, save_json, save_markdown
15
+ from ..export import convert_json, convert_csv, convert_html, convert_markdown
16
+
13
17
  logger = set_logger(__name__, "INFO")
14
18
 
15
19
 
20
+ def merge_all_pages(results):
21
+ out = None
22
+ for result in results:
23
+ format = result["format"]
24
+ data = result["data"]
25
+
26
+ if format == "json":
27
+ if out is None:
28
+ out = [data]
29
+ else:
30
+ out.append(data)
31
+
32
+ elif format == "csv":
33
+ if out is None:
34
+ out = data
35
+ else:
36
+ out.extend(data)
37
+
38
+ elif format == "html":
39
+ if out is None:
40
+ out = data
41
+ else:
42
+ out += "\n" + data
43
+
44
+ elif format == "md":
45
+ if out is None:
46
+ out = data
47
+ else:
48
+ out += "\n" + data
49
+
50
+ return out
51
+
52
+
53
+ def save_merged_file(out_path, args, out):
54
+ if args.format == "json":
55
+ save_json(out, out_path, args.encoding)
56
+ elif args.format == "csv":
57
+ save_csv(out, out_path, args.encoding)
58
+ elif args.format == "html":
59
+ save_html(out, out_path, args.encoding)
60
+ elif args.format == "md":
61
+ save_markdown(out, out_path, args.encoding)
62
+
63
+
64
+ def validate_encoding(encoding):
65
+ if encoding not in [
66
+ "utf-8",
67
+ "utf-8-sig",
68
+ "shift-jis",
69
+ "euc-jp",
70
+ "cp932",
71
+ ]:
72
+ raise ValueError(f"Invalid encoding: {encoding}")
73
+ return True
74
+
75
+
16
76
  def process_single_file(args, analyzer, path, format):
17
77
  if path.suffix[1:].lower() in ["pdf"]:
18
78
  imgs = load_pdf(path)
19
79
  else:
20
- imgs = [load_image(path)]
80
+ imgs = load_image(path)
21
81
 
82
+ results = []
22
83
  for page, img in enumerate(imgs):
23
- results, ocr, layout = analyzer(img)
24
-
84
+ result, ocr, layout = analyzer(img)
25
85
  dirname = path.parent.name
26
86
  filename = path.stem
27
87
 
88
+ # cv2.imwrite(
89
+ # os.path.join(args.outdir, f"{dirname}_{filename}_p{page+1}.jpg"), img
90
+ # )
91
+
28
92
  if ocr is not None:
29
93
  out_path = os.path.join(
30
94
  args.outdir, f"{dirname}_{filename}_p{page+1}_ocr.jpg"
@@ -44,37 +108,129 @@ def process_single_file(args, analyzer, path, format):
44
108
  out_path = os.path.join(args.outdir, f"{dirname}_{filename}_p{page+1}.{format}")
45
109
 
46
110
  if format == "json":
47
- results.to_json(
48
- out_path,
49
- ignore_line_break=args.ignore_line_break,
111
+ if args.combine:
112
+ json = convert_json(
113
+ result,
114
+ out_path,
115
+ args.ignore_line_break,
116
+ img,
117
+ args.figure,
118
+ args.figure_dir,
119
+ )
120
+ else:
121
+ json = result.to_json(
122
+ out_path,
123
+ ignore_line_break=args.ignore_line_break,
124
+ encoding=args.encoding,
125
+ img=img,
126
+ export_figure=args.figure,
127
+ figure_dir=args.figure_dir,
128
+ )
129
+
130
+ results.append(
131
+ {
132
+ "format": format,
133
+ "data": json.model_dump(),
134
+ }
50
135
  )
136
+
51
137
  elif format == "csv":
52
- results.to_csv(
53
- out_path,
54
- ignore_line_break=args.ignore_line_break,
138
+ if args.combine:
139
+ csv = convert_csv(
140
+ result,
141
+ out_path,
142
+ args.ignore_line_break,
143
+ img,
144
+ args.figure,
145
+ args.figure_dir,
146
+ )
147
+ else:
148
+ csv = result.to_csv(
149
+ out_path,
150
+ ignore_line_break=args.ignore_line_break,
151
+ encoding=args.encoding,
152
+ img=img,
153
+ export_figure=args.figure,
154
+ figure_dir=args.figure_dir,
155
+ )
156
+
157
+ results.append(
158
+ {
159
+ "format": format,
160
+ "data": csv,
161
+ }
55
162
  )
163
+
56
164
  elif format == "html":
57
- results.to_html(
58
- out_path,
59
- ignore_line_break=args.ignore_line_break,
60
- img=img,
61
- export_figure=args.figure,
62
- export_figure_letter=args.figure_letter,
63
- figure_width=args.figure_width,
64
- figure_dir=args.figure_dir,
165
+ if args.combine:
166
+ html, _ = convert_html(
167
+ result,
168
+ out_path,
169
+ ignore_line_break=args.ignore_line_break,
170
+ img=img,
171
+ export_figure=args.figure,
172
+ export_figure_letter=args.figure_letter,
173
+ figure_width=args.figure_width,
174
+ figure_dir=args.figure_dir,
175
+ )
176
+ else:
177
+ html = result.to_html(
178
+ out_path,
179
+ ignore_line_break=args.ignore_line_break,
180
+ img=img,
181
+ export_figure=args.figure,
182
+ export_figure_letter=args.figure_letter,
183
+ figure_width=args.figure_width,
184
+ figure_dir=args.figure_dir,
185
+ encoding=args.encoding,
186
+ )
187
+
188
+ results.append(
189
+ {
190
+ "format": format,
191
+ "data": html,
192
+ }
65
193
  )
194
+
66
195
  elif format == "md":
67
- results.to_markdown(
68
- out_path,
69
- ignore_line_break=args.ignore_line_break,
70
- img=img,
71
- export_figure=args.figure,
72
- export_figure_letter=args.figure_letter,
73
- figure_width=args.figure_width,
74
- figure_dir=args.figure_dir,
196
+ if args.combine:
197
+ md, _ = convert_markdown(
198
+ result,
199
+ out_path,
200
+ ignore_line_break=args.ignore_line_break,
201
+ img=img,
202
+ export_figure=args.figure,
203
+ export_figure_letter=args.figure_letter,
204
+ figure_width=args.figure_width,
205
+ figure_dir=args.figure_dir,
206
+ )
207
+ else:
208
+ md = result.to_markdown(
209
+ out_path,
210
+ ignore_line_break=args.ignore_line_break,
211
+ img=img,
212
+ export_figure=args.figure,
213
+ export_figure_letter=args.figure_letter,
214
+ figure_width=args.figure_width,
215
+ figure_dir=args.figure_dir,
216
+ encoding=args.encoding,
217
+ )
218
+
219
+ results.append(
220
+ {
221
+ "format": format,
222
+ "data": md,
223
+ }
75
224
  )
76
225
 
77
- logger.info(f"Output file: {out_path}")
226
+ out = merge_all_pages(results)
227
+ if args.combine:
228
+ out_path = os.path.join(args.outdir, f"{dirname}_{filename}.{format}")
229
+ save_merged_file(
230
+ out_path,
231
+ args,
232
+ out,
233
+ )
78
234
 
79
235
 
80
236
  def main():
@@ -104,6 +260,12 @@ def main():
104
260
  default="results",
105
261
  help="output directory",
106
262
  )
263
+ parser.add_argument(
264
+ "-l",
265
+ "--lite",
266
+ action="store_true",
267
+ help="if set, use lite model",
268
+ )
107
269
  parser.add_argument(
108
270
  "-d",
109
271
  "--device",
@@ -162,6 +324,22 @@ def main():
162
324
  default="figures",
163
325
  help="directory to save figure images",
164
326
  )
327
+ parser.add_argument(
328
+ "--encoding",
329
+ type=str,
330
+ default="utf-8",
331
+ help="Specifies the character encoding for the output file to be exported. If unsupported characters are included, they will be ignored.",
332
+ )
333
+ parser.add_argument(
334
+ "--combine",
335
+ action="store_true",
336
+ help="if set, merge all pages in the output",
337
+ )
338
+ parser.add_argument(
339
+ "--ignore_meta",
340
+ action="store_true",
341
+ help="if set, ignore meta information(header, footer) in the output",
342
+ )
165
343
 
166
344
  args = parser.parse_args()
167
345
 
@@ -175,6 +353,8 @@ def main():
175
353
  f"Invalid output format: {args.format}. Supported formats are {SUPPORT_OUTPUT_FORMAT}"
176
354
  )
177
355
 
356
+ validate_encoding(args.encoding)
357
+
178
358
  if format == "markdown":
179
359
  format = "md"
180
360
 
@@ -197,10 +377,22 @@ def main():
197
377
  },
198
378
  }
199
379
 
380
+ if args.lite:
381
+ configs["ocr"]["text_recognizer"]["model_name"] = "parseq-small"
382
+
383
+ if args.device == "cpu" or not torch.cuda.is_available():
384
+ configs["ocr"]["text_detector"]["infer_onnx"] = True
385
+
386
+ # Note: Text Detector以外はONNX推論よりもPyTorch推論の方が速いため、ONNX推論は行わない
387
+ # configs["ocr"]["text_recognizer"]["infer_onnx"] = True
388
+ # configs["layout_analyzer"]["table_structure_recognizer"]["infer_onnx"] = True
389
+ # configs["layout_analyzer"]["layout_parser"]["infer_onnx"] = True
390
+
200
391
  analyzer = DocumentAnalyzer(
201
392
  configs=configs,
202
393
  visualize=args.vis,
203
394
  device=args.device,
395
+ ignore_meta=args.ignore_meta,
204
396
  )
205
397
 
206
398
  os.makedirs(args.outdir, exist_ok=True)
yomitoku/configs/__init__.py CHANGED
@@ -4,10 +4,12 @@ from .cfg_table_structure_recognizer_rtdtrv2 import (
4
4
  )
5
5
  from .cfg_text_detector_dbnet import TextDetectorDBNetConfig
6
6
  from .cfg_text_recognizer_parseq import TextRecognizerPARSeqConfig
7
+ from .cfg_text_recognizer_parseq_small import TextRecognizerPARSeqSmallConfig
7
8
 
8
9
  __all__ = [
9
10
  "TextDetectorDBNetConfig",
10
11
  "TextRecognizerPARSeqConfig",
11
12
  "LayoutParserRTDETRv2Config",
12
13
  "TableStructureRecognizerRTDETRv2Config",
14
+ "TextRecognizerPARSeqSmallConfig",
13
15
  ]
yomitoku/configs/cfg_text_detector_dbnet.py CHANGED
@@ -30,7 +30,7 @@ class PostProcess:
30
30
  thresh: float = 0.2
31
31
  box_thresh: float = 0.5
32
32
  max_candidates: int = 1500
33
- unclip_ratio: float = 2.0
33
+ unclip_ratio: float = 7.0
34
34
 
35
35
 
36
36
  @dataclass
yomitoku/configs/cfg_text_recognizer_parseq_small.py ADDED
@@ -0,0 +1,51 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ from ..constants import ROOT_DIR
5
+
6
+
7
+ @dataclass
8
+ class Data:
9
+ num_workers: int = 4
10
+ batch_size: int = 128
11
+ img_size: List[int] = field(default_factory=lambda: [32, 800])
12
+
13
+
14
+ @dataclass
15
+ class Encoder:
16
+ patch_size: List[int] = field(default_factory=lambda: [16, 16])
17
+ num_heads: int = 8
18
+ embed_dim: int = 384
19
+ mlp_ratio: int = 4
20
+ depth: int = 9
21
+
22
+
23
+ @dataclass
24
+ class Decoder:
25
+ embed_dim: int = 384
26
+ num_heads: int = 8
27
+ mlp_ratio: int = 4
28
+ depth: int = 1
29
+
30
+
31
+ @dataclass
32
+ class Visualize:
33
+ font: str = str(ROOT_DIR + "/resource/MPLUS1p-Medium.ttf")
34
+ color: List[int] = field(default_factory=lambda: [0, 0, 255]) # RGB
35
+ font_size: int = 18
36
+
37
+
38
+ @dataclass
39
+ class TextRecognizerPARSeqSmallConfig:
40
+ hf_hub_repo: str = "KotaroKinoshita/yomitoku-text-recognizer-parseq-small-open-beta"
41
+ charset: str = str(ROOT_DIR + "/resource/charset.txt")
42
+ num_tokens: int = 7312
43
+ max_label_length: int = 100
44
+ decode_ar: int = 1
45
+ refine_iters: int = 1
46
+
47
+ data: Data = field(default_factory=Data)
48
+ encoder: Encoder = field(default_factory=Encoder)
49
+ decoder: Decoder = field(default_factory=Decoder)
50
+
51
+ visualize: Visualize = field(default_factory=Visualize)
yomitoku/data/functions.py CHANGED
@@ -1,9 +1,10 @@
1
1
  from pathlib import Path
2
2
 
3
3
  import cv2
4
+ from PIL import Image
4
5
  import numpy as np
5
6
  import torch
6
- from pdf2image import convert_from_path
7
+ import pypdfium2
7
8
 
8
9
  from ..constants import (
9
10
  MIN_IMAGE_SIZE,
@@ -15,6 +16,20 @@ from ..utils.logger import set_logger
15
16
  logger = set_logger(__name__)
16
17
 
17
18
 
19
+ def validate_image(img: np.ndarray):
20
+ h, w = img.shape[:2]
21
+ if h < MIN_IMAGE_SIZE or w < MIN_IMAGE_SIZE:
22
+ raise ValueError("Image size is too small.")
23
+
24
+ if min(h, w) < WARNING_IMAGE_SIZE:
25
+ logger.warning(
26
+ """
27
+ The image size is small, which may result in reduced OCR accuracy.
28
+ The process will continue, but it is recommended to input images with a minimum size of 720 pixels on the shorter side.
29
+ """
30
+ )
31
+
32
+
18
33
  def load_image(image_path: str) -> np.ndarray:
19
34
  """
20
35
  Open an image file.
@@ -40,24 +55,27 @@ def load_image(image_path: str) -> np.ndarray:
40
55
  "PDF file is not supported by load_image(). Use load_pdf() instead."
41
56
  )
42
57
 
43
- img = cv2.imread(image_path, cv2.IMREAD_COLOR)
44
-
45
- if img is None:
58
+ try:
59
+ img = Image.open(image_path)
60
+ except Exception:
46
61
  raise ValueError("Invalid image data.")
47
62
 
48
- h, w = img.shape[:2]
49
- if h < MIN_IMAGE_SIZE or w < MIN_IMAGE_SIZE:
50
- raise ValueError("Image size is too small.")
51
-
52
- if min(h, w) < WARNING_IMAGE_SIZE:
53
- logger.warning(
54
- """
55
- The image size is small, which may result in reduced OCR accuracy.
56
- The process will continue, but it is recommended to input images with a minimum size of 720 pixels on the shorter side.
57
- """
58
- )
63
+ pages = []
64
+ if ext in ["tif", "tiff"]:
65
+ try:
66
+ while True:
67
+ img_arr = np.array(img.copy().convert("RGB"))
68
+ validate_image(img_arr)
69
+ pages.append(img_arr[:, :, ::-1])
70
+ img.seek(img.tell() + 1)
71
+ except EOFError:
72
+ pass
73
+ else:
74
+ img_arr = np.array(img.convert("RGB"))
75
+ validate_image(img_arr)
76
+ pages.append(img_arr[:, :, ::-1])
59
77
 
60
- return img
78
+ return pages
61
79
 
62
80
 
63
81
  def load_pdf(pdf_path: str, dpi=200) -> list[np.ndarray]:
@@ -70,6 +88,7 @@ def load_pdf(pdf_path: str, dpi=200) -> list[np.ndarray]:
70
88
  Returns:
71
89
  list[np.ndarray]: list of image data(BGR)
72
90
  """
91
+
73
92
  pdf_path = Path(pdf_path)
74
93
  if not pdf_path.exists():
75
94
  raise FileNotFoundError(f"File not found: {pdf_path}")
@@ -86,11 +105,19 @@ def load_pdf(pdf_path: str, dpi=200) -> list[np.ndarray]:
86
105
  )
87
106
 
88
107
  try:
89
- images = convert_from_path(pdf_path, dpi=dpi)
108
+ doc = pypdfium2.PdfDocument(pdf_path)
109
+ renderer = doc.render(
110
+ pypdfium2.PdfBitmap.to_pil,
111
+ scale=dpi / 72,
112
+ )
113
+ images = list(renderer)
114
+ images = [np.array(image.convert("RGB"))[:, :, ::-1] for image in images]
115
+
116
+ doc.close()
90
117
  except Exception as e:
91
118
  raise ValueError(f"Failed to open the PDF file: {pdf_path}") from e
92
119
 
93
- return [np.array(img)[:, :, ::-1] for img in images]
120
+ return images
94
121
 
95
122
 
96
123
  def resize_shortest_edge(
@@ -123,7 +150,7 @@ def resize_shortest_edge(
123
150
  neww = max(int(new_w / 32) * 32, 32)
124
151
  newh = max(int(new_h / 32) * 32, 32)
125
152
 
126
- img = cv2.resize(img, (neww, newh))
153
+ img = cv2.resize(img, (neww, newh), interpolation=cv2.INTER_AREA)
127
154
  return img
128
155
 
129
156
 
@@ -193,9 +220,7 @@ def validate_quads(img: np.ndarray, quads: list[list[list[int]]]):
193
220
  h, w = img.shape[:2]
194
221
 
195
222
  if x1 < 0 or x2 > w or y1 < 0 or y2 > h:
196
- raise ValueError(
197
- f"The vertices are out of the image. {quad.tolist()}"
198
- )
223
+ raise ValueError(f"The vertices are out of the image. {quad.tolist()}")
199
224
 
200
225
  return True
201
226
 
@@ -268,7 +293,7 @@ def resize_with_padding(img, target_size, background_color=(0, 0, 0)):
268
293
  new_w = int(w * min(scale_w, scale_h))
269
294
  new_h = int(h * min(scale_w, scale_h))
270
295
 
271
- resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
296
+ resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
272
297
 
273
298
  canvas = np.zeros((target_size[0], target_size[1], 3), dtype=np.uint8)
274
299
  canvas[:, :] = background_color