yomitoku 0.4.0.post1.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. yomitoku/__init__.py +20 -0
  2. yomitoku/base.py +136 -0
  3. yomitoku/cli/__init__.py +0 -0
  4. yomitoku/cli/main.py +230 -0
  5. yomitoku/configs/__init__.py +13 -0
  6. yomitoku/configs/cfg_layout_parser_rtdtrv2.py +89 -0
  7. yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py +80 -0
  8. yomitoku/configs/cfg_text_detector_dbnet.py +49 -0
  9. yomitoku/configs/cfg_text_recognizer_parseq.py +51 -0
  10. yomitoku/constants.py +32 -0
  11. yomitoku/data/__init__.py +3 -0
  12. yomitoku/data/dataset.py +40 -0
  13. yomitoku/data/functions.py +279 -0
  14. yomitoku/document_analyzer.py +315 -0
  15. yomitoku/export/__init__.py +6 -0
  16. yomitoku/export/export_csv.py +71 -0
  17. yomitoku/export/export_html.py +188 -0
  18. yomitoku/export/export_json.py +34 -0
  19. yomitoku/export/export_markdown.py +145 -0
  20. yomitoku/layout_analyzer.py +66 -0
  21. yomitoku/layout_parser.py +189 -0
  22. yomitoku/models/__init__.py +9 -0
  23. yomitoku/models/dbnet_plus.py +272 -0
  24. yomitoku/models/layers/__init__.py +0 -0
  25. yomitoku/models/layers/activate.py +38 -0
  26. yomitoku/models/layers/dbnet_feature_attention.py +160 -0
  27. yomitoku/models/layers/parseq_transformer.py +218 -0
  28. yomitoku/models/layers/rtdetr_backbone.py +333 -0
  29. yomitoku/models/layers/rtdetr_hybrid_encoder.py +433 -0
  30. yomitoku/models/layers/rtdetrv2_decoder.py +811 -0
  31. yomitoku/models/parseq.py +243 -0
  32. yomitoku/models/rtdetr.py +22 -0
  33. yomitoku/ocr.py +87 -0
  34. yomitoku/postprocessor/__init__.py +9 -0
  35. yomitoku/postprocessor/dbnet_postporcessor.py +137 -0
  36. yomitoku/postprocessor/parseq_tokenizer.py +128 -0
  37. yomitoku/postprocessor/rtdetr_postprocessor.py +107 -0
  38. yomitoku/reading_order.py +214 -0
  39. yomitoku/resource/MPLUS1p-Medium.ttf +0 -0
  40. yomitoku/resource/charset.txt +1 -0
  41. yomitoku/table_structure_recognizer.py +244 -0
  42. yomitoku/text_detector.py +103 -0
  43. yomitoku/text_recognizer.py +128 -0
  44. yomitoku/utils/__init__.py +0 -0
  45. yomitoku/utils/graph.py +20 -0
  46. yomitoku/utils/logger.py +15 -0
  47. yomitoku/utils/misc.py +102 -0
  48. yomitoku/utils/visualizer.py +179 -0
  49. yomitoku-0.4.0.post1.dev0.dist-info/METADATA +127 -0
  50. yomitoku-0.4.0.post1.dev0.dist-info/RECORD +52 -0
  51. yomitoku-0.4.0.post1.dev0.dist-info/WHEEL +4 -0
  52. yomitoku-0.4.0.post1.dev0.dist-info/entry_points.txt +2 -0
yomitoku/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ from importlib.metadata import version
2
+
3
+ from .document_analyzer import DocumentAnalyzer
4
+ from .layout_analyzer import LayoutAnalyzer
5
+ from .layout_parser import LayoutParser
6
+ from .ocr import OCR
7
+ from .table_structure_recognizer import TableStructureRecognizer
8
+ from .text_detector import TextDetector
9
+ from .text_recognizer import TextRecognizer
10
+
11
# Names re-exported as the package's public API (all imported above).
__all__ = [
    "OCR",
    "LayoutParser",
    "TableStructureRecognizer",
    "TextDetector",
    "TextRecognizer",
    "LayoutAnalyzer",
    "DocumentAnalyzer",
]
# Version string looked up from the installed distribution's metadata.
# NOTE(review): importlib.metadata.version presumably raises
# PackageNotFoundError when the package is imported without being
# installed (e.g. from a bare source checkout) - confirm.
__version__ = version(__package__)
yomitoku/base.py ADDED
@@ -0,0 +1,136 @@
1
+ import time
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ import torch
6
+ from omegaconf import OmegaConf
7
+ from pydantic import BaseModel, Extra
8
+
9
+ from .export import export_json
10
+ from .utils.logger import set_logger
11
+
12
+ logger = set_logger(__name__, "INFO")
13
+
14
+
15
def load_yaml_config(path_config: str):
    """Load a YAML configuration file with OmegaConf.

    Args:
        path_config: Path of the YAML file to read.

    Returns:
        The configuration object produced by ``OmegaConf.load``.

    Raises:
        FileNotFoundError: If ``path_config`` does not exist.
    """
    path_config = Path(path_config)
    if not path_config.exists():
        raise FileNotFoundError(f"Config file not found: {path_config}")

    # Decode explicitly as UTF-8 so parsing does not depend on the platform's
    # locale encoding (which is not UTF-8 on every system).
    with open(path_config, "r", encoding="utf-8") as file:
        return OmegaConf.load(file)
23
+
24
+
25
def load_config(
    default_config,
    path_config: Union[str, None] = None,
):
    """Build a config from ``default_config``, optionally overlaid with a YAML file.

    Args:
        default_config: Object handed to ``OmegaConf.structured`` to create
            the base configuration.
        path_config: Optional path of a YAML file; when given, its content is
            merged on top of the base configuration.

    Returns:
        The structured base configuration, or the result of merging the YAML
        file into it when ``path_config`` is provided.
    """
    base_cfg = OmegaConf.structured(default_config)
    if path_config is None:
        return base_cfg

    overrides = load_yaml_config(path_config)
    return OmegaConf.merge(base_cfg, overrides)
34
+
35
+
36
def observer(cls, func):
    """Wrap ``func`` so that every call is timed and logged under ``cls``'s name.

    Args:
        cls: Class whose ``__name__`` is used in the log messages.
        func: Callable to wrap.

    Returns:
        A wrapper that forwards all arguments to ``func``, logs the elapsed
        time at INFO level on success, and logs at ERROR level before
        re-raising when ``func`` raises an ``Exception``.
    """
    # Function-scope import keeps this block self-contained.
    import functools

    @functools.wraps(func)  # keep the wrapped callable's name and docstring
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measurement is not affected by
        # system clock adjustments.
        start = time.perf_counter()
        try:
            result = func(*args, **kwargs)
        except Exception as e:
            logger.error(f"Error occurred in {cls.__name__} {func.__name__}: {e}")
            # Bare raise re-raises the active exception unchanged.
            raise
        elapsed = time.perf_counter() - start
        logger.info(f"{cls.__name__} {func.__name__} elapsed_time: {elapsed}")
        return result

    return wrapper
49
+
50
+
51
class BaseSchema(BaseModel):
    """Pydantic base model with strict field handling and a JSON export helper."""

    class Config:
        # Reject fields that are not declared on the model.
        extra = Extra.forbid
        # Re-run validation whenever an attribute is assigned.
        validate_assignment = True

    def to_json(self, out_path: str, **kwargs):
        """Export this object via ``export_json``, forwarding ``out_path`` and ``kwargs``."""
        export_json(self, out_path, **kwargs)
58
+
59
+
60
class BaseModule:
    """Base class for callable modules backed by a model catalog.

    Subclasses must set ``model_catalog`` to a ``BaseModelCatalog`` instance
    with at least one registered model. The class-level ``__call__`` is
    wrapped with ``observer`` so calls are timed and logged.
    """

    model_catalog = None

    def __init__(self):
        """Validate that the subclass provides a usable ``model_catalog``.

        Raises:
            NotImplementedError: If ``model_catalog`` is ``None``.
            ValueError: If ``model_catalog`` is not a ``BaseModelCatalog``
                instance, or if it has no registered model.
        """
        if self.model_catalog is None:
            raise NotImplementedError

        if not isinstance(self.model_catalog, BaseModelCatalog):
            raise ValueError(
                f"{self.model_catalog.__class__} is not SubClass BaseModelCatalog."
            )

        if len(self.model_catalog.list_model()) == 0:
            raise ValueError("No model is registered.")

    def __new__(cls, *args, **kwds):
        """Log the instantiation and make sure ``cls.__call__`` is observed."""
        logger.info(f"Initialize {cls.__name__}")
        # Wrap only once: re-wrapping on every instantiation would stack one
        # more observer layer (and one more log line per call) for each
        # instance ever created. A subclass inheriting an already observed
        # __call__ keeps the inherited wrapper.
        if not getattr(cls.__call__, "_yomitoku_observed", False):
            observed = observer(cls, cls.__call__)
            observed._yomitoku_observed = True
            cls.__call__ = observed
        return super().__new__(cls)

    def load_model(self, name, path_cfg, from_pretrained=True):
        """Look up ``name`` in the catalog, build its config and create the model.

        Args:
            name: Model name passed to ``model_catalog.get``.
            path_cfg: Optional YAML path merged over the default config.
            from_pretrained: When true, the model is created with
                ``Net.from_pretrained(cfg.hf_hub_repo, cfg=cfg)``; otherwise
                with ``Net(cfg=cfg)``.
        """
        default_cfg, Net = self.model_catalog.get(name)
        self._cfg = load_config(default_cfg, path_cfg)
        if from_pretrained:
            self.model = Net.from_pretrained(self._cfg.hf_hub_repo, cfg=self._cfg)
        else:
            self.model = Net(cfg=self._cfg)

    def save_config(self, path_cfg):
        """Write the current config to ``path_cfg`` with ``OmegaConf.save``."""
        OmegaConf.save(self._cfg, path_cfg)

    def log_config(self):
        """Log the current config rendered as YAML."""
        logger.info(OmegaConf.to_yaml(self._cfg))

    @classmethod
    def catalog(cls):
        """Log the names of the models registered in ``model_catalog``."""
        # Each name is followed by a single space (including the last one).
        display = "".join(f"{model} " for model in cls.model_catalog.list_model())
        logger.info(f"{cls.__name__} Implemented Models")
        logger.info(display)

    @property
    def device(self):
        """The ``torch.device`` selected through the ``device`` setter."""
        return self._device

    @device.setter
    def device(self, device):
        # str() lets callers pass a torch.device as well as a string; the
        # substring test on a torch.device object would raise TypeError.
        if "cuda" in str(device):
            if torch.cuda.is_available():
                self._device = torch.device(device)
            else:
                # Fall back to CPU rather than failing when CUDA is missing.
                self._device = torch.device("cpu")
                logger.warning("CUDA is not available. Use CPU instead.")
        else:
            # Any non-CUDA request resolves to CPU.
            self._device = torch.device("cpu")
116
+
117
+
118
class BaseModelCatalog:
    """Registry mapping case-insensitive model names to ``(config, model)`` pairs."""

    def __init__(self):
        # Lower-cased model name -> (config, model)
        self.catalog = {}

    def get(self, model_name):
        """Return the ``(config, model)`` pair registered under ``model_name``.

        The lookup is case-insensitive.

        Raises:
            ValueError: If no model is registered under that name.
        """
        model_name = model_name.lower()
        if model_name in self.catalog:
            return self.catalog[model_name]

        raise ValueError(f"Unknown model: {model_name}")

    def register(self, model_name, config, model):
        """Register ``(config, model)`` under ``model_name``.

        The name is stored lower-cased so that it matches the lower-cased
        lookup performed by ``get``; otherwise a name registered with
        upper-case letters could never be retrieved.

        Raises:
            ValueError: If the (case-insensitive) name is already registered.
        """
        key = model_name.lower()
        if key in self.catalog:
            raise ValueError(f"{model_name} is already registered.")

        self.catalog[key] = (config, model)

    def list_model(self):
        """Return the registered model names as a list."""
        return list(self.catalog.keys())
yomitoku/cli/__init__.py ADDED
File without changes
yomitoku/cli/main.py ADDED
@@ -0,0 +1,230 @@
1
+ import argparse
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import cv2
6
+ import time
7
+
8
+ from ..constants import SUPPORT_OUTPUT_FORMAT
9
+ from ..data.functions import load_image, load_pdf
10
+ from ..document_analyzer import DocumentAnalyzer
11
+ from ..utils.logger import set_logger
12
+
13
+ logger = set_logger(__name__, "INFO")
14
+
15
+
16
def process_single_file(args, analyzer, path, format):
    """Analyze one input file and write per-page results into ``args.outdir``.

    A ``.pdf`` input is loaded with ``load_pdf`` (one image per page); any
    other input is loaded with ``load_image`` as a single page. For each page
    the analyzer output is exported in ``format`` and, when the analyzer
    returns them, the OCR and layout visualizations are saved as JPEG files.

    Args:
        args: Parsed CLI arguments (``outdir``, ``ignore_line_break`` and the
            ``figure*`` options are read here).
        analyzer: Callable returning ``(results, ocr, layout)`` for an image.
        path: ``pathlib.Path`` of the input file.
        format: Output format: ``"json"``, ``"csv"``, ``"html"`` or ``"md"``.
    """
    is_pdf = path.suffix[1:].lower() in ["pdf"]
    pages = load_pdf(path) if is_pdf else [load_image(path)]

    for page, img in enumerate(pages):
        results, ocr, layout = analyzer(img)

        # Common file-name stem: <parent dir>_<file stem>_p<1-based page>
        stem = f"{path.parent.name}_{path.stem}_p{page+1}"

        # Visualizations are optional; the analyzer returns None for the
        # ones it did not produce.
        for vis_img, tag in ((ocr, "ocr"), (layout, "layout")):
            if vis_img is None:
                continue
            vis_path = os.path.join(args.outdir, f"{stem}_{tag}.jpg")
            cv2.imwrite(vis_path, vis_img)
            logger.info(f"Output file: {vis_path}")

        out_path = os.path.join(args.outdir, f"{stem}.{format}")

        if format in ("json", "csv"):
            export = results.to_json if format == "json" else results.to_csv
            export(
                out_path,
                ignore_line_break=args.ignore_line_break,
            )
        elif format in ("html", "md"):
            # HTML and Markdown exports additionally take the figure options.
            export = results.to_html if format == "html" else results.to_markdown
            export(
                out_path,
                ignore_line_break=args.ignore_line_break,
                img=img,
                export_figure=args.figure,
                export_figure_letter=args.figure_letter,
                figure_width=args.figure_width,
                figure_dir=args.figure_dir,
            )

        logger.info(f"Output file: {out_path}")
78
+
79
+
80
def _build_arg_parser():
    """Create the argument parser for the command-line interface."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "arg1", type=str, help="path of target image file or directory"
    )
    parser.add_argument(
        "-f",
        "--format",
        type=str,
        default="json",
        help="output format type (json or csv or html or md)",
    )
    parser.add_argument(
        "-v", "--vis", action="store_true", help="if set, visualize the result"
    )
    parser.add_argument(
        "-o", "--outdir", type=str, default="results", help="output directory"
    )
    parser.add_argument(
        "-d", "--device", type=str, default="cuda", help="device to use"
    )
    parser.add_argument(
        "--td_cfg", type=str, default=None, help="path of text detector config file"
    )
    parser.add_argument(
        "--tr_cfg",
        type=str,
        default=None,
        help="path of text recognizer config file",
    )
    parser.add_argument(
        "--lp_cfg", type=str, default=None, help="path of layout parser config file"
    )
    parser.add_argument(
        "--tsr_cfg",
        type=str,
        default=None,
        help="path of table structure recognizer config file",
    )
    parser.add_argument(
        "--ignore_line_break",
        action="store_true",
        help="if set, ignore line break in the output",
    )
    parser.add_argument(
        "--figure", action="store_true", help="if set, export figure in the output"
    )
    parser.add_argument(
        "--figure_letter",
        action="store_true",
        help="if set, export letter within figure in the output",
    )
    parser.add_argument(
        "--figure_width",
        type=int,
        default=200,
        help="width of figure image in the output",
    )
    parser.add_argument(
        "--figure_dir",
        type=str,
        default="figures",
        help="directory to save figure images",
    )
    return parser


def main():
    """Command-line entry point: analyze a file, or every file under a directory.

    Raises:
        FileNotFoundError: If the target path does not exist.
        ValueError: If the requested output format is not supported.
    """
    args = _build_arg_parser().parse_args()

    path = Path(args.arg1)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {args.arg1}")

    # Local name chosen so the builtin ``format`` is not shadowed.
    out_format = args.format.lower()
    if out_format not in SUPPORT_OUTPUT_FORMAT:
        raise ValueError(
            f"Invalid output format: {args.format}. Supported formats are {SUPPORT_OUTPUT_FORMAT}"
        )

    # "markdown" is accepted as an alias; files are written with ".md".
    if out_format == "markdown":
        out_format = "md"

    configs = {
        "ocr": {
            "text_detector": {
                "path_cfg": args.td_cfg,
            },
            "text_recognizer": {
                "path_cfg": args.tr_cfg,
            },
        },
        "layout_analyzer": {
            "layout_parser": {
                "path_cfg": args.lp_cfg,
            },
            "table_structure_recognizer": {
                "path_cfg": args.tsr_cfg,
            },
        },
    }

    analyzer = DocumentAnalyzer(
        configs=configs,
        visualize=args.vis,
        device=args.device,
    )

    os.makedirs(args.outdir, exist_ok=True)
    logger.info(f"Output directory: {args.outdir}")

    if path.is_dir():
        all_files = [f for f in path.rglob("*") if f.is_file()]
        for f in all_files:
            file_path = Path(f)
            try:
                start = time.time()
                logger.info(f"Processing file: {file_path}")
                process_single_file(args, analyzer, file_path, out_format)
                end = time.time()
                logger.info(f"Total Processing time: {end-start:.2f} sec")
            except Exception as e:
                # Best effort: one bad file must not stop the batch, but the
                # failure is reported instead of being dropped silently.
                logger.error(f"Failed to process file: {file_path}: {e}")
                continue
    else:
        start = time.time()
        logger.info(f"Processing file: {path}")
        process_single_file(args, analyzer, path, out_format)
        end = time.time()
        logger.info(f"Total Processing time: {end-start:.2f} sec")
227
+
228
+
229
+ if __name__ == "__main__":
230
+ main()
yomitoku/configs/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .cfg_layout_parser_rtdtrv2 import LayoutParserRTDETRv2Config
2
+ from .cfg_table_structure_recognizer_rtdtrv2 import (
3
+ TableStructureRecognizerRTDETRv2Config,
4
+ )
5
+ from .cfg_text_detector_dbnet import TextDetectorDBNetConfig
6
+ from .cfg_text_recognizer_parseq import TextRecognizerPARSeqConfig
7
+
8
+ __all__ = [
9
+ "TextDetectorDBNetConfig",
10
+ "TextRecognizerPARSeqConfig",
11
+ "LayoutParserRTDETRv2Config",
12
+ "TableStructureRecognizerRTDETRv2Config",
13
+ ]
yomitoku/configs/cfg_layout_parser_rtdtrv2.py ADDED
@@ -0,0 +1,89 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+
5
@dataclass
class Data:
    """Image-size settings; used as the ``data`` section of ``LayoutParserRTDETRv2Config``."""

    # Two-element size, 640 x 640 by default.
    # NOTE(review): the axis order (height/width) is not visible here - confirm.
    img_size: List[int] = field(default_factory=lambda: [640, 640])
8
+
9
+
10
@dataclass
class BackBone:
    """Backbone options; used as the ``PResNet`` section of ``LayoutParserRTDETRv2Config``."""

    depth: int = 50
    variant: str = "d"
    freeze_at: int = 0
    # Defaults to indices 1, 2 and 3 - presumably the stages whose features
    # are returned; verify against the backbone implementation.
    return_idx: List[int] = field(default_factory=lambda: [1, 2, 3])
    num_stages: int = 4
    freeze_norm: bool = True
18
+
19
+
20
@dataclass
class Encoder:
    """Encoder options; used as the ``HybridEncoder`` section of ``LayoutParserRTDETRv2Config``."""

    # Both lists default to three entries, presumably one per input feature
    # level - TODO confirm against the encoder implementation.
    in_channels: List[int] = field(default_factory=lambda: [512, 1024, 2048])
    feat_strides: List[int] = field(default_factory=lambda: [8, 16, 32])

    # intra
    hidden_dim: int = 256
    use_encoder_idx: List[int] = field(default_factory=lambda: [2])
    num_encoder_layers: int = 1
    nhead: int = 8
    dim_feedforward: int = 1024
    dropout: float = 0.0
    enc_act: str = "gelu"

    # cross
    expansion: float = 1.0
    depth_mult: int = 1
    act: str = "silu"
38
+
39
+
40
@dataclass
class Decoder:
    """Decoder options; used as the ``RTDETRTransformerv2`` section of ``LayoutParserRTDETRv2Config``."""

    # 6 equals the number of entries in LayoutParserRTDETRv2Config's default
    # ``category`` list.
    num_classes: int = 6
    feat_channels: List[int] = field(default_factory=lambda: [256, 256, 256])
    feat_strides: List[int] = field(default_factory=lambda: [8, 16, 32])
    hidden_dim: int = 256
    num_levels: int = 3

    num_layers: int = 6
    num_queries: int = 300

    # NOTE(review): the denoising/noise options below look training-related;
    # their effect at inference is not visible here - confirm.
    num_denoising: int = 100
    label_noise_ratio: float = 0.5
    box_noise_scale: float = 1.0
    # Same default values as Data.img_size (640 x 640).
    eval_spatial_size: List[int] = field(default_factory=lambda: [640, 640])

    eval_idx: int = -1

    num_points: List[int] = field(default_factory=lambda: [4, 4, 4])
    cross_attn_method: str = "default"
    query_select_method: str = "default"
61
+
62
+
63
@dataclass
class LayoutParserRTDETRv2Config:
    """Top-level default configuration for the RT-DETRv2 layout parser."""

    # Repository identifier stored under ``hf_hub_repo``.
    hf_hub_repo: str = "KotaroKinoshita/yomitoku-layout-parser-rtdtrv2-open-beta"
    # NOTE(review): presumably the minimum detection score kept - confirm
    # against the post-processing code.
    thresh_score: float = 0.5
    data: Data = field(default_factory=Data)
    PResNet: BackBone = field(default_factory=BackBone)
    HybridEncoder: Encoder = field(default_factory=Encoder)
    RTDETRTransformerv2: Decoder = field(default_factory=Decoder)

    # Six category names (Decoder.num_classes defaults to 6).
    # NOTE(review): presumably indexed by predicted class id - confirm.
    category: List[str] = field(
        default_factory=lambda: [
            "tables",
            "figures",
            "paragraphs",
            "section_headings",
            "page_header",
            "page_footer",
        ]
    )

    # Every entry here also appears in ``category``.
    role: List[str] = field(
        default_factory=lambda: [
            "section_headings",
            "page_header",
            "page_footer",
        ]
    )
yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py ADDED
@@ -0,0 +1,80 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+
5
@dataclass
class Data:
    """Image-size settings; used as the ``data`` section of ``TableStructureRecognizerRTDETRv2Config``."""

    # Two-element size, 640 x 640 by default.
    # NOTE(review): the axis order (height/width) is not visible here - confirm.
    img_size: List[int] = field(default_factory=lambda: [640, 640])
8
+
9
+
10
@dataclass
class BackBone:
    """Backbone options; used as the ``PResNet`` section of ``TableStructureRecognizerRTDETRv2Config``."""

    depth: int = 50
    variant: str = "d"
    freeze_at: int = 0
    # Defaults to indices 1, 2 and 3 - presumably the stages whose features
    # are returned; verify against the backbone implementation.
    return_idx: List[int] = field(default_factory=lambda: [1, 2, 3])
    num_stages: int = 4
    freeze_norm: bool = True
18
+
19
+
20
@dataclass
class Encoder:
    """Encoder options; used as the ``HybridEncoder`` section of ``TableStructureRecognizerRTDETRv2Config``."""

    # Both lists default to three entries, presumably one per input feature
    # level - TODO confirm against the encoder implementation.
    in_channels: List[int] = field(default_factory=lambda: [512, 1024, 2048])
    feat_strides: List[int] = field(default_factory=lambda: [8, 16, 32])

    # intra
    hidden_dim: int = 256
    use_encoder_idx: List[int] = field(default_factory=lambda: [2])
    num_encoder_layers: int = 1
    nhead: int = 8
    dim_feedforward: int = 1024
    dropout: float = 0.0
    enc_act: str = "gelu"

    # cross
    expansion: float = 1.0
    depth_mult: int = 1
    act: str = "silu"
38
+
39
+
40
@dataclass
class Decoder:
    """Decoder options; used as the ``RTDETRTransformerv2`` section of ``TableStructureRecognizerRTDETRv2Config``."""

    # 3 equals the number of entries in
    # TableStructureRecognizerRTDETRv2Config's default ``category`` list.
    num_classes: int = 3
    feat_channels: List[int] = field(default_factory=lambda: [256, 256, 256])
    feat_strides: List[int] = field(default_factory=lambda: [8, 16, 32])
    hidden_dim: int = 256
    num_levels: int = 3

    num_layers: int = 6
    num_queries: int = 300

    # NOTE(review): the denoising/noise options below look training-related;
    # their effect at inference is not visible here - confirm.
    num_denoising: int = 100
    label_noise_ratio: float = 0.5
    # Original author note read "1.0 0.4" - presumably the values that were
    # tried; 1.0 is the one in effect.
    box_noise_scale: float = 1.0
    # Same default values as Data.img_size (640 x 640).
    eval_spatial_size: List[int] = field(default_factory=lambda: [640, 640])

    eval_idx: int = -1

    num_points: List[int] = field(default_factory=lambda: [4, 4, 4])
    cross_attn_method: str = "default"
    query_select_method: str = "default"
61
+
62
+
63
@dataclass
class TableStructureRecognizerRTDETRv2Config:
    """Top-level default configuration for the RT-DETRv2 table structure recognizer."""

    # Repository identifier stored under ``hf_hub_repo``.
    hf_hub_repo: str = (
        "KotaroKinoshita/yomitoku-table-structure-recognizer-rtdtrv2-open-beta"
    )
    # NOTE(review): presumably the minimum detection score kept - confirm
    # against the post-processing code.
    thresh_score: float = 0.4
    data: Data = field(default_factory=Data)
    PResNet: BackBone = field(default_factory=BackBone)
    HybridEncoder: Encoder = field(default_factory=Encoder)
    RTDETRTransformerv2: Decoder = field(default_factory=Decoder)

    # Three category names (Decoder.num_classes defaults to 3).
    # NOTE(review): presumably indexed by predicted class id - confirm.
    category: List[str] = field(
        default_factory=lambda: [
            "row",
            "col",
            "span",
        ]
    )
yomitoku/configs/cfg_text_detector_dbnet.py ADDED
@@ -0,0 +1,49 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+
5
@dataclass
class BackBone:
    """Backbone options; used as the ``backbone`` section of ``TextDetectorDBNetConfig``."""

    # Backbone identifier string.
    name: str = "resnet50"
    # NOTE(review): presumably enables dilated convolutions in the backbone -
    # confirm against the model code.
    dilation: bool = True
9
+
10
+
11
@dataclass
class Decoder:
    """Decoder options; used as the ``decoder`` section of ``TextDetectorDBNetConfig``."""

    # Annotated with typing.List like every other list field in this file
    # (the original used the builtin generic ``list[int]`` here only).
    in_channels: List[int] = field(default_factory=lambda: [256, 512, 1024, 2048])
    hidden_dim: int = 256
    adaptive: bool = True
    serial: bool = True
    smooth: bool = False
    # NOTE(review): meaning of ``k`` is not visible in this file - confirm
    # against the DBNet decoder implementation.
    k: int = 50
19
+
20
+
21
@dataclass
class Data:
    """Image-size settings; used as the ``data`` section of ``TextDetectorDBNetConfig``."""

    # NOTE(review): presumably the target length of the shorter image side
    # and the cap on the longer side when resizing - confirm against the
    # preprocessing code.
    shortest_size: int = 1280
    limit_size: int = 1600
25
+
26
+
27
@dataclass
class PostProcess:
    """Post-processing options; used as the ``post_process`` section of ``TextDetectorDBNetConfig``."""

    # NOTE(review): the exact meaning of each threshold is defined by the
    # post-processor, which is not visible here - confirm before tuning.
    min_size: int = 2
    thresh: float = 0.2
    box_thresh: float = 0.5
    max_candidates: int = 1500
    unclip_ratio: float = 2.0
34
+
35
+
36
@dataclass
class Visualize:
    """Visualization options; used as the ``visualize`` section of ``TextDetectorDBNetConfig``."""

    # Three-component color.
    # NOTE(review): channel order (RGB vs BGR) is not stated here - confirm.
    color: List[int] = field(default_factory=lambda: [0, 255, 0])
    heatmap: bool = False
40
+
41
+
42
@dataclass
class TextDetectorDBNetConfig:
    """Top-level default configuration for the DBNet text detector."""

    # Repository identifier stored under ``hf_hub_repo``.
    hf_hub_repo: str = "KotaroKinoshita/yomitoku-text-detector-dbnet-open-beta"
    backbone: BackBone = field(default_factory=BackBone)
    decoder: Decoder = field(default_factory=Decoder)
    data: Data = field(default_factory=Data)
    post_process: PostProcess = field(default_factory=PostProcess)
    visualize: Visualize = field(default_factory=Visualize)
yomitoku/configs/cfg_text_recognizer_parseq.py ADDED
@@ -0,0 +1,51 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ from ..constants import ROOT_DIR
5
+
6
+
7
@dataclass
class Data:
    """Data-loading settings; used as the ``data`` section of ``TextRecognizerPARSeqConfig``."""

    num_workers: int = 4
    batch_size: int = 128
    # Two-element size, [32, 800] by default.
    # NOTE(review): presumably (height, width) of a text-line crop - confirm.
    img_size: List[int] = field(default_factory=lambda: [32, 800])
12
+
13
+
14
@dataclass
class Encoder:
    """Encoder options; used as the ``encoder`` section of ``TextRecognizerPARSeqConfig``."""

    patch_size: List[int] = field(default_factory=lambda: [8, 8])
    num_heads: int = 8
    # Same default value as Decoder.embed_dim (512).
    embed_dim: int = 512
    mlp_ratio: int = 4
    depth: int = 12
21
+
22
+
23
@dataclass
class Decoder:
    """Decoder options; used as the ``decoder`` section of ``TextRecognizerPARSeqConfig``."""

    # Same default value as Encoder.embed_dim (512).
    embed_dim: int = 512
    num_heads: int = 8
    mlp_ratio: int = 4
    depth: int = 1
29
+
30
+
31
@dataclass
class Visualize:
    """Visualization options; used as the ``visualize`` section of ``TextRecognizerPARSeqConfig``."""

    # Font file path built by appending to ROOT_DIR (imported from
    # ..constants) with "/" separators.
    font: str = str(ROOT_DIR + "/resource/MPLUS1p-Medium.ttf")
    color: List[int] = field(default_factory=lambda: [0, 0, 255])  # RGB
    font_size: int = 18
36
+
37
+
38
@dataclass
class TextRecognizerPARSeqConfig:
    """Top-level default configuration for the PARSeq text recognizer."""

    # Repository identifier stored under ``hf_hub_repo``.
    hf_hub_repo: str = "KotaroKinoshita/yomitoku-text-recognizer-parseq-open-beta"
    # Character-set file path built by appending to ROOT_DIR (imported from
    # ..constants) with "/" separators.
    charset: str = str(ROOT_DIR + "/resource/charset.txt")
    # NOTE(review): presumably tied to the size of the charset file plus
    # special tokens - confirm before changing either one.
    num_tokens: int = 7312
    max_label_length: int = 100
    # NOTE(review): typed as int; presumably used as an on/off switch for
    # autoregressive decoding - confirm against the model code.
    decode_ar: int = 1
    refine_iters: int = 1

    data: Data = field(default_factory=Data)
    encoder: Encoder = field(default_factory=Encoder)
    decoder: Decoder = field(default_factory=Decoder)

    visualize: Visualize = field(default_factory=Visualize)
yomitoku/constants.py ADDED
@@ -0,0 +1,32 @@
1
+ import os
2
+
3
# Absolute path of the directory that contains this module.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# Accepted output format names; "markdown" and "md" are both listed.
SUPPORT_OUTPUT_FORMAT = ["json", "csv", "html", "markdown", "md"]
# NOTE(review): presumably input file extensions (lower-case, without the
# dot) - confirm against the loaders that use this list.
SUPPORT_INPUT_FORMAT = ["jpg", "jpeg", "png", "bmp", "tiff", "tif", "pdf"]
# NOTE(review): units (presumably pixels) and how these two limits are
# applied are not visible in this module - confirm against callers.
MIN_IMAGE_SIZE = 32
WARNING_IMAGE_SIZE = 720

# 22 three-component colors with channel values drawn from {0, 128, 255}.
# NOTE(review): channel order (RGB vs BGR) is not stated here - confirm.
PALETTE = [
    [255, 0, 0],
    [0, 255, 0],
    [0, 0, 255],
    [255, 255, 0],
    [0, 255, 255],
    [255, 0, 255],
    [128, 0, 0],
    [0, 128, 0],
    [0, 0, 128],
    [255, 128, 0],
    [0, 255, 128],
    [128, 0, 255],
    [128, 255, 0],
    [0, 128, 255],
    [255, 0, 128],
    [255, 128, 128],
    [128, 255, 128],
    [128, 128, 255],
    [255, 255, 128],
    [255, 128, 255],
    [128, 255, 255],
    [128, 128, 128],
]