yomitoku 0.4.0.post1.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- yomitoku/__init__.py +20 -0
- yomitoku/base.py +136 -0
- yomitoku/cli/__init__.py +0 -0
- yomitoku/cli/main.py +230 -0
- yomitoku/configs/__init__.py +13 -0
- yomitoku/configs/cfg_layout_parser_rtdtrv2.py +89 -0
- yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py +80 -0
- yomitoku/configs/cfg_text_detector_dbnet.py +49 -0
- yomitoku/configs/cfg_text_recognizer_parseq.py +51 -0
- yomitoku/constants.py +32 -0
- yomitoku/data/__init__.py +3 -0
- yomitoku/data/dataset.py +40 -0
- yomitoku/data/functions.py +279 -0
- yomitoku/document_analyzer.py +315 -0
- yomitoku/export/__init__.py +6 -0
- yomitoku/export/export_csv.py +71 -0
- yomitoku/export/export_html.py +188 -0
- yomitoku/export/export_json.py +34 -0
- yomitoku/export/export_markdown.py +145 -0
- yomitoku/layout_analyzer.py +66 -0
- yomitoku/layout_parser.py +189 -0
- yomitoku/models/__init__.py +9 -0
- yomitoku/models/dbnet_plus.py +272 -0
- yomitoku/models/layers/__init__.py +0 -0
- yomitoku/models/layers/activate.py +38 -0
- yomitoku/models/layers/dbnet_feature_attention.py +160 -0
- yomitoku/models/layers/parseq_transformer.py +218 -0
- yomitoku/models/layers/rtdetr_backbone.py +333 -0
- yomitoku/models/layers/rtdetr_hybrid_encoder.py +433 -0
- yomitoku/models/layers/rtdetrv2_decoder.py +811 -0
- yomitoku/models/parseq.py +243 -0
- yomitoku/models/rtdetr.py +22 -0
- yomitoku/ocr.py +87 -0
- yomitoku/postprocessor/__init__.py +9 -0
- yomitoku/postprocessor/dbnet_postporcessor.py +137 -0
- yomitoku/postprocessor/parseq_tokenizer.py +128 -0
- yomitoku/postprocessor/rtdetr_postprocessor.py +107 -0
- yomitoku/reading_order.py +214 -0
- yomitoku/resource/MPLUS1p-Medium.ttf +0 -0
- yomitoku/resource/charset.txt +1 -0
- yomitoku/table_structure_recognizer.py +244 -0
- yomitoku/text_detector.py +103 -0
- yomitoku/text_recognizer.py +128 -0
- yomitoku/utils/__init__.py +0 -0
- yomitoku/utils/graph.py +20 -0
- yomitoku/utils/logger.py +15 -0
- yomitoku/utils/misc.py +102 -0
- yomitoku/utils/visualizer.py +179 -0
- yomitoku-0.4.0.post1.dev0.dist-info/METADATA +127 -0
- yomitoku-0.4.0.post1.dev0.dist-info/RECORD +52 -0
- yomitoku-0.4.0.post1.dev0.dist-info/WHEEL +4 -0
- yomitoku-0.4.0.post1.dev0.dist-info/entry_points.txt +2 -0
yomitoku/__init__.py
ADDED
@@ -0,0 +1,20 @@
from importlib.metadata import version

from .document_analyzer import DocumentAnalyzer
from .layout_analyzer import LayoutAnalyzer
from .layout_parser import LayoutParser
from .ocr import OCR
from .table_structure_recognizer import TableStructureRecognizer
from .text_detector import TextDetector
from .text_recognizer import TextRecognizer

__all__ = [
    "OCR",
    "LayoutParser",
    "TableStructureRecognizer",
    "TextDetector",
    "TextRecognizer",
    "LayoutAnalyzer",
    "DocumentAnalyzer",
]
__version__ = version(__package__)
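Taken together with the call site in cli/main.py below, the package-level exports suggest the following minimal programmatic usage. This is only a sketch under assumptions: the input file name is hypothetical, and DocumentAnalyzer's constructor signature is not shown in this diff, so its arguments and the triple return value are mirrored from cli/main.py.

import cv2

from yomitoku import DocumentAnalyzer

# Hypothetical input image path.
img = cv2.imread("sample.jpg")

# configs structure copied from cli/main.py; None means "use bundled defaults".
configs = {
    "ocr": {
        "text_detector": {"path_cfg": None},
        "text_recognizer": {"path_cfg": None},
    },
    "layout_analyzer": {
        "layout_parser": {"path_cfg": None},
        "table_structure_recognizer": {"path_cfg": None},
    },
}

analyzer = DocumentAnalyzer(configs=configs, visualize=True, device="cpu")
results, ocr_vis, layout_vis = analyzer(img)  # return triple mirrors cli/main.py
results.to_json("sample.json")  # to_json comes from BaseSchema in base.py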
yomitoku/base.py
ADDED
@@ -0,0 +1,136 @@
import time
from pathlib import Path
from typing import Union

import torch
from omegaconf import OmegaConf
from pydantic import BaseModel, Extra

from .export import export_json
from .utils.logger import set_logger

logger = set_logger(__name__, "INFO")


def load_yaml_config(path_config: str):
    path_config = Path(path_config)
    if not path_config.exists():
        raise FileNotFoundError(f"Config file not found: {path_config}")

    with open(path_config, "r") as file:
        yaml_config = OmegaConf.load(file)
    return yaml_config


def load_config(
    default_config,
    path_config: Union[str, None] = None,
):
    cfg = OmegaConf.structured(default_config)
    if path_config is not None:
        yaml_config = load_yaml_config(path_config)
        cfg = OmegaConf.merge(cfg, yaml_config)
    return cfg


def observer(cls, func):
    def wrapper(*args, **kwargs):
        try:
            start = time.time()
            result = func(*args, **kwargs)
            elapsed = time.time() - start
            logger.info(f"{cls.__name__} {func.__name__} elapsed_time: {elapsed}")
        except Exception as e:
            logger.error(f"Error occurred in {cls.__name__} {func.__name__}: {e}")
            raise e
        return result

    return wrapper


class BaseSchema(BaseModel):
    class Config:
        extra = Extra.forbid
        validate_assignment = True

    def to_json(self, out_path: str, **kwargs):
        export_json(self, out_path, **kwargs)


class BaseModule:
    model_catalog = None

    def __init__(self):
        if self.model_catalog is None:
            raise NotImplementedError

        if not issubclass(self.model_catalog.__class__, BaseModelCatalog):
            raise ValueError(
                f"{self.model_catalog.__class__} is not SubClass BaseModelCatalog."
            )

        if len(self.model_catalog.list_model()) == 0:
            raise ValueError("No model is registered.")

    def __new__(cls, *args, **kwds):
        logger.info(f"Initialize {cls.__name__}")
        cls.__call__ = observer(cls, cls.__call__)
        return super().__new__(cls)

    def load_model(self, name, path_cfg, from_pretrained=True):
        default_cfg, Net = self.model_catalog.get(name)
        self._cfg = load_config(default_cfg, path_cfg)
        if from_pretrained:
            self.model = Net.from_pretrained(self._cfg.hf_hub_repo, cfg=self._cfg)
        else:
            self.model = Net(cfg=self._cfg)

    def save_config(self, path_cfg):
        OmegaConf.save(self._cfg, path_cfg)

    def log_config(self):
        logger.info(OmegaConf.to_yaml(self._cfg))

    @classmethod
    def catalog(cls):
        display = ""
        for model in cls.model_catalog.list_model():
            display += f"{model} "
        logger.info(f"{cls.__name__} Implemented Models")
        logger.info(display)

    @property
    def device(self):
        return self._device

    @device.setter
    def device(self, device):
        if "cuda" in device:
            if torch.cuda.is_available():
                self._device = torch.device(device)
            else:
                self._device = torch.device("cpu")
                logger.warning("CUDA is not available. Use CPU instead.")
        else:
            self._device = torch.device("cpu")


class BaseModelCatalog:
    def __init__(self):
        self.catalog = {}

    def get(self, model_name):
        model_name = model_name.lower()
        if model_name in self.catalog:
            return self.catalog[model_name]

        raise ValueError(f"Unknown model: {model_name}")

    def register(self, model_name, config, model):
        if model_name in self.catalog:
            raise ValueError(f"{model_name} is already registered.")

        self.catalog[model_name] = (config, model)

    def list_model(self):
        return list(self.catalog.keys())
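base.py defines the contract between BaseModelCatalog (a name-to-(config, model) registry) and BaseModule (which looks up a registered entry, merges an optional YAML override via load_config, and instantiates the network). A minimal sketch of that contract follows; the Toy* classes and the repo id are hypothetical and not part of the package.

from dataclasses import dataclass

from yomitoku.base import BaseModelCatalog, BaseModule


@dataclass
class ToyConfig:
    hf_hub_repo: str = "user/toy-model"  # assumed repo id, unused when from_pretrained=False


class ToyNet:
    def __init__(self, cfg):
        self.cfg = cfg


class ToyCatalog(BaseModelCatalog):
    def __init__(self):
        super().__init__()
        # register(name, config_dataclass, model_class) as defined above
        self.register("toy", ToyConfig, ToyNet)


class ToyModule(BaseModule):
    model_catalog = ToyCatalog()

    def __init__(self, model_name="toy", path_cfg=None):
        super().__init__()
        # from_pretrained=False avoids a Hugging Face Hub download in this sketch
        self.load_model(model_name, path_cfg, from_pretrained=False)

    def __call__(self, x):
        # __call__ is wrapped by observer() in BaseModule.__new__, so each call
        # logs its elapsed time
        return x


module = ToyModule()
module("dummy input")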
yomitoku/cli/__init__.py
ADDED
File without changes
yomitoku/cli/main.py
ADDED
@@ -0,0 +1,230 @@
import argparse
import os
from pathlib import Path

import cv2
import time

from ..constants import SUPPORT_OUTPUT_FORMAT
from ..data.functions import load_image, load_pdf
from ..document_analyzer import DocumentAnalyzer
from ..utils.logger import set_logger

logger = set_logger(__name__, "INFO")


def process_single_file(args, analyzer, path, format):
    if path.suffix[1:].lower() in ["pdf"]:
        imgs = load_pdf(path)
    else:
        imgs = [load_image(path)]

    for page, img in enumerate(imgs):
        results, ocr, layout = analyzer(img)

        dirname = path.parent.name
        filename = path.stem

        if ocr is not None:
            out_path = os.path.join(
                args.outdir, f"{dirname}_{filename}_p{page+1}_ocr.jpg"
            )

            cv2.imwrite(out_path, ocr)
            logger.info(f"Output file: {out_path}")

        if layout is not None:
            out_path = os.path.join(
                args.outdir, f"{dirname}_{filename}_p{page+1}_layout.jpg"
            )

            cv2.imwrite(out_path, layout)
            logger.info(f"Output file: {out_path}")

        out_path = os.path.join(args.outdir, f"{dirname}_{filename}_p{page+1}.{format}")

        if format == "json":
            results.to_json(
                out_path,
                ignore_line_break=args.ignore_line_break,
            )
        elif format == "csv":
            results.to_csv(
                out_path,
                ignore_line_break=args.ignore_line_break,
            )
        elif format == "html":
            results.to_html(
                out_path,
                ignore_line_break=args.ignore_line_break,
                img=img,
                export_figure=args.figure,
                export_figure_letter=args.figure_letter,
                figure_width=args.figure_width,
                figure_dir=args.figure_dir,
            )
        elif format == "md":
            results.to_markdown(
                out_path,
                ignore_line_break=args.ignore_line_break,
                img=img,
                export_figure=args.figure,
                export_figure_letter=args.figure_letter,
                figure_width=args.figure_width,
                figure_dir=args.figure_dir,
            )

        logger.info(f"Output file: {out_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "arg1",
        type=str,
        help="path of target image file or directory",
    )
    parser.add_argument(
        "-f",
        "--format",
        type=str,
        default="json",
        help="output format type (json or csv or html or md)",
    )
    parser.add_argument(
        "-v",
        "--vis",
        action="store_true",
        help="if set, visualize the result",
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=str,
        default="results",
        help="output directory",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cuda",
        help="device to use",
    )
    parser.add_argument(
        "--td_cfg",
        type=str,
        default=None,
        help="path of text detector config file",
    )
    parser.add_argument(
        "--tr_cfg",
        type=str,
        default=None,
        help="path of text recognizer config file",
    )
    parser.add_argument(
        "--lp_cfg",
        type=str,
        default=None,
        help="path of layout parser config file",
    )
    parser.add_argument(
        "--tsr_cfg",
        type=str,
        default=None,
        help="path of table structure recognizer config file",
    )
    parser.add_argument(
        "--ignore_line_break",
        action="store_true",
        help="if set, ignore line break in the output",
    )
    parser.add_argument(
        "--figure",
        action="store_true",
        help="if set, export figure in the output",
    )
    parser.add_argument(
        "--figure_letter",
        action="store_true",
        help="if set, export letter within figure in the output",
    )
    parser.add_argument(
        "--figure_width",
        type=int,
        default=200,
        help="width of figure image in the output",
    )
    parser.add_argument(
        "--figure_dir",
        type=str,
        default="figures",
        help="directory to save figure images",
    )

    args = parser.parse_args()

    path = Path(args.arg1)
    if not path.exists():
        raise FileNotFoundError(f"File not found: {args.arg1}")

    format = args.format.lower()
    if format not in SUPPORT_OUTPUT_FORMAT:
        raise ValueError(
            f"Invalid output format: {args.format}. Supported formats are {SUPPORT_OUTPUT_FORMAT}"
        )

    if format == "markdown":
        format = "md"

    configs = {
        "ocr": {
            "text_detector": {
                "path_cfg": args.td_cfg,
            },
            "text_recognizer": {
                "path_cfg": args.tr_cfg,
            },
        },
        "layout_analyzer": {
            "layout_parser": {
                "path_cfg": args.lp_cfg,
            },
            "table_structure_recognizer": {
                "path_cfg": args.tsr_cfg,
            },
        },
    }

    analyzer = DocumentAnalyzer(
        configs=configs,
        visualize=args.vis,
        device=args.device,
    )

    os.makedirs(args.outdir, exist_ok=True)
    logger.info(f"Output directory: {args.outdir}")

    if path.is_dir():
        all_files = [f for f in path.rglob("*") if f.is_file()]
        for f in all_files:
            try:
                start = time.time()
                file_path = Path(f)
                logger.info(f"Processing file: {file_path}")
                process_single_file(args, analyzer, file_path, format)
                end = time.time()
                logger.info(f"Total Processing time: {end-start:.2f} sec")
            except Exception:
                continue
    else:
        start = time.time()
        logger.info(f"Processing file: {path}")
        process_single_file(args, analyzer, path, format)
        end = time.time()
        logger.info(f"Total Processing time: {end-start:.2f} sec")


if __name__ == "__main__":
    main()
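A hypothetical invocation of the CLI above. The installed console-script name comes from entry_points.txt, whose contents are not included in this diff, so the module form is shown; all flags are the ones registered in main(), and demo.pdf is an illustrative input.

    python -m yomitoku.cli.main demo.pdf -f md -o results -v --figure --device cpu

Per process_single_file, a PDF input is split into pages and each page is exported to its own output file (plus optional OCR/layout visualization images when -v is set).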
yomitoku/configs/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .cfg_layout_parser_rtdtrv2 import LayoutParserRTDETRv2Config
from .cfg_table_structure_recognizer_rtdtrv2 import (
    TableStructureRecognizerRTDETRv2Config,
)
from .cfg_text_detector_dbnet import TextDetectorDBNetConfig
from .cfg_text_recognizer_parseq import TextRecognizerPARSeqConfig

__all__ = [
    "TextDetectorDBNetConfig",
    "TextRecognizerPARSeqConfig",
    "LayoutParserRTDETRv2Config",
    "TableStructureRecognizerRTDETRv2Config",
]
yomitoku/configs/cfg_layout_parser_rtdtrv2.py
ADDED
@@ -0,0 +1,89 @@
from dataclasses import dataclass, field
from typing import List


@dataclass
class Data:
    img_size: List[int] = field(default_factory=lambda: [640, 640])


@dataclass
class BackBone:
    depth: int = 50
    variant: str = "d"
    freeze_at: int = 0
    return_idx: List[int] = field(default_factory=lambda: [1, 2, 3])
    num_stages: int = 4
    freeze_norm: bool = True


@dataclass
class Encoder:
    in_channels: List[int] = field(default_factory=lambda: [512, 1024, 2048])
    feat_strides: List[int] = field(default_factory=lambda: [8, 16, 32])

    # intra
    hidden_dim: int = 256
    use_encoder_idx: List[int] = field(default_factory=lambda: [2])
    num_encoder_layers: int = 1
    nhead: int = 8
    dim_feedforward: int = 1024
    dropout: float = 0.0
    enc_act: str = "gelu"

    # cross
    expansion: float = 1.0
    depth_mult: int = 1
    act: str = "silu"


@dataclass
class Decoder:
    num_classes: int = 6
    feat_channels: List[int] = field(default_factory=lambda: [256, 256, 256])
    feat_strides: List[int] = field(default_factory=lambda: [8, 16, 32])
    hidden_dim: int = 256
    num_levels: int = 3

    num_layers: int = 6
    num_queries: int = 300

    num_denoising: int = 100
    label_noise_ratio: float = 0.5
    box_noise_scale: float = 1.0
    eval_spatial_size: List[int] = field(default_factory=lambda: [640, 640])

    eval_idx: int = -1

    num_points: List[int] = field(default_factory=lambda: [4, 4, 4])
    cross_attn_method: str = "default"
    query_select_method: str = "default"


@dataclass
class LayoutParserRTDETRv2Config:
    hf_hub_repo: str = "KotaroKinoshita/yomitoku-layout-parser-rtdtrv2-open-beta"
    thresh_score: float = 0.5
    data: Data = field(default_factory=Data)
    PResNet: BackBone = field(default_factory=BackBone)
    HybridEncoder: Encoder = field(default_factory=Encoder)
    RTDETRTransformerv2: Decoder = field(default_factory=Decoder)

    category: List[str] = field(
        default_factory=lambda: [
            "tables",
            "figures",
            "paragraphs",
            "section_headings",
            "page_header",
            "page_footer",
        ]
    )

    role: List[str] = field(
        default_factory=lambda: [
            "section_headings",
            "page_header",
            "page_footer",
        ]
    )
yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py
ADDED
@@ -0,0 +1,80 @@
from dataclasses import dataclass, field
from typing import List


@dataclass
class Data:
    img_size: List[int] = field(default_factory=lambda: [640, 640])


@dataclass
class BackBone:
    depth: int = 50
    variant: str = "d"
    freeze_at: int = 0
    return_idx: List[int] = field(default_factory=lambda: [1, 2, 3])
    num_stages: int = 4
    freeze_norm: bool = True


@dataclass
class Encoder:
    in_channels: List[int] = field(default_factory=lambda: [512, 1024, 2048])
    feat_strides: List[int] = field(default_factory=lambda: [8, 16, 32])

    # intra
    hidden_dim: int = 256
    use_encoder_idx: List[int] = field(default_factory=lambda: [2])
    num_encoder_layers: int = 1
    nhead: int = 8
    dim_feedforward: int = 1024
    dropout: float = 0.0
    enc_act: str = "gelu"

    # cross
    expansion: float = 1.0
    depth_mult: int = 1
    act: str = "silu"


@dataclass
class Decoder:
    num_classes: int = 3
    feat_channels: List[int] = field(default_factory=lambda: [256, 256, 256])
    feat_strides: List[int] = field(default_factory=lambda: [8, 16, 32])
    hidden_dim: int = 256
    num_levels: int = 3

    num_layers: int = 6
    num_queries: int = 300

    num_denoising: int = 100
    label_noise_ratio: float = 0.5
    box_noise_scale: float = 1.0  # 1.0 0.4
    eval_spatial_size: List[int] = field(default_factory=lambda: [640, 640])

    eval_idx: int = -1

    num_points: List[int] = field(default_factory=lambda: [4, 4, 4])
    cross_attn_method: str = "default"
    query_select_method: str = "default"


@dataclass
class TableStructureRecognizerRTDETRv2Config:
    hf_hub_repo: str = (
        "KotaroKinoshita/yomitoku-table-structure-recognizer-rtdtrv2-open-beta"
    )
    thresh_score: float = 0.4
    data: Data = field(default_factory=Data)
    PResNet: BackBone = field(default_factory=BackBone)
    HybridEncoder: Encoder = field(default_factory=Encoder)
    RTDETRTransformerv2: Decoder = field(default_factory=Decoder)

    category: List[str] = field(
        default_factory=lambda: [
            "row",
            "col",
            "span",
        ]
    )
yomitoku/configs/cfg_text_detector_dbnet.py
ADDED
@@ -0,0 +1,49 @@
from dataclasses import dataclass, field
from typing import List


@dataclass
class BackBone:
    name: str = "resnet50"
    dilation: bool = True


@dataclass
class Decoder:
    in_channels: list[int] = field(default_factory=lambda: [256, 512, 1024, 2048])
    hidden_dim: int = 256
    adaptive: bool = True
    serial: bool = True
    smooth: bool = False
    k: int = 50


@dataclass
class Data:
    shortest_size: int = 1280
    limit_size: int = 1600


@dataclass
class PostProcess:
    min_size: int = 2
    thresh: float = 0.2
    box_thresh: float = 0.5
    max_candidates: int = 1500
    unclip_ratio: float = 2.0


@dataclass
class Visualize:
    color: List[int] = field(default_factory=lambda: [0, 255, 0])
    heatmap: bool = False


@dataclass
class TextDetectorDBNetConfig:
    hf_hub_repo: str = "KotaroKinoshita/yomitoku-text-detector-dbnet-open-beta"
    backbone: BackBone = field(default_factory=BackBone)
    decoder: Decoder = field(default_factory=Decoder)
    data: Data = field(default_factory=Data)
    post_process: PostProcess = field(default_factory=PostProcess)
    visualize: Visualize = field(default_factory=Visualize)
yomitoku/configs/cfg_text_recognizer_parseq.py
ADDED
@@ -0,0 +1,51 @@
from dataclasses import dataclass, field
from typing import List

from ..constants import ROOT_DIR


@dataclass
class Data:
    num_workers: int = 4
    batch_size: int = 128
    img_size: List[int] = field(default_factory=lambda: [32, 800])


@dataclass
class Encoder:
    patch_size: List[int] = field(default_factory=lambda: [8, 8])
    num_heads: int = 8
    embed_dim: int = 512
    mlp_ratio: int = 4
    depth: int = 12


@dataclass
class Decoder:
    embed_dim: int = 512
    num_heads: int = 8
    mlp_ratio: int = 4
    depth: int = 1


@dataclass
class Visualize:
    font: str = str(ROOT_DIR + "/resource/MPLUS1p-Medium.ttf")
    color: List[int] = field(default_factory=lambda: [0, 0, 255])  # RGB
    font_size: int = 18


@dataclass
class TextRecognizerPARSeqConfig:
    hf_hub_repo: str = "KotaroKinoshita/yomitoku-text-recognizer-parseq-open-beta"
    charset: str = str(ROOT_DIR + "/resource/charset.txt")
    num_tokens: int = 7312
    max_label_length: int = 100
    decode_ar: int = 1
    refine_iters: int = 1

    data: Data = field(default_factory=Data)
    encoder: Encoder = field(default_factory=Encoder)
    decoder: Decoder = field(default_factory=Decoder)

    visualize: Visualize = field(default_factory=Visualize)
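The config dataclasses above supply the defaults that base.load_config merges with an optional user YAML, which is presumably how the CLI's --td_cfg/--tr_cfg/--lp_cfg/--tsr_cfg paths reach the models via BaseModule.load_model. A small sketch of that override path; the YAML body and temporary file are illustrative assumptions.

import tempfile

from yomitoku.base import load_config
from yomitoku.configs import TextDetectorDBNetConfig

# Hypothetical user override: raise the binarization threshold and unclip ratio.
yaml_override = "post_process:\n  thresh: 0.3\n  unclip_ratio: 2.5\n"

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(yaml_override)
    path_cfg = f.name

cfg = load_config(TextDetectorDBNetConfig, path_cfg)
print(cfg.post_process.thresh)       # 0.3, merged over the 0.2 default above
print(cfg.post_process.box_thresh)   # 0.5, untouched default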
yomitoku/constants.py
ADDED
@@ -0,0 +1,32 @@
import os

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
SUPPORT_OUTPUT_FORMAT = ["json", "csv", "html", "markdown", "md"]
SUPPORT_INPUT_FORMAT = ["jpg", "jpeg", "png", "bmp", "tiff", "tif", "pdf"]
MIN_IMAGE_SIZE = 32
WARNING_IMAGE_SIZE = 720

PALETTE = [
    [255, 0, 0],
    [0, 255, 0],
    [0, 0, 255],
    [255, 255, 0],
    [0, 255, 255],
    [255, 0, 255],
    [128, 0, 0],
    [0, 128, 0],
    [0, 0, 128],
    [255, 128, 0],
    [0, 255, 128],
    [128, 0, 255],
    [128, 255, 0],
    [0, 128, 255],
    [255, 0, 128],
    [255, 128, 128],
    [128, 255, 128],
    [128, 128, 255],
    [255, 255, 128],
    [255, 128, 255],
    [128, 255, 255],
    [128, 128, 128],
]