teklia-layout-reader 0.2.1 (teklia_layout_reader-0.2.1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- layout_reader/__init__.py +9 -0
- layout_reader/cli.py +27 -0
- layout_reader/datasets/__init__.py +99 -0
- layout_reader/datasets/analyze.py +161 -0
- layout_reader/datasets/extract.py +289 -0
- layout_reader/datasets/lsd.py +133 -0
- layout_reader/datasets/utils.py +128 -0
- layout_reader/helpers.py +358 -0
- layout_reader/inference.py +215 -0
- layout_reader/train/sft.py +69 -0
- teklia_layout_reader-0.2.1.dist-info/METADATA +62 -0
- teklia_layout_reader-0.2.1.dist-info/RECORD +22 -0
- teklia_layout_reader-0.2.1.dist-info/WHEEL +5 -0
- teklia_layout_reader-0.2.1.dist-info/entry_points.txt +2 -0
- teklia_layout_reader-0.2.1.dist-info/top_level.txt +2 -0
- tests/__init__.py +3 -0
- tests/conftest.py +19 -0
- tests/test_analyze.py +14 -0
- tests/test_cli.py +11 -0
- tests/test_extract.py +130 -0
- tests/test_helpers.py +438 -0
- tests/test_predict.py +64 -0
layout_reader/cli.py
ADDED
@@ -0,0 +1,27 @@

import argparse

from layout_reader.datasets import add_dataset_command
from layout_reader.inference import add_predict_command
from layout_reader.train.sft import add_training_command


def main():
    parser = argparse.ArgumentParser(
        prog="teklia-layout-reader",
        description="Scripts for Layout Reader",
    )

    commands = parser.add_subparsers(metavar="subcommands")
    add_dataset_command(commands)
    add_training_command(commands)
    add_predict_command(commands)

    args = vars(parser.parse_args())
    if "func" in args:
        args.pop("func")(**args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
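The subcommand wiring above relies on argparse's `set_defaults(func=...)` idiom: each subparser attaches its handler function, and `main` pops `func` out of the parsed namespace and calls it with the remaining entries as keyword arguments. A minimal, self-contained sketch of the same pattern (the `greet` subcommand and handler are hypothetical, not part of the package):

import argparse

def greet(name: str) -> None:
    # Hypothetical handler; stands in for run_dataset_extraction etc.
    print(f"Hello, {name}!")

parser = argparse.ArgumentParser(prog="demo")
commands = parser.add_subparsers(metavar="subcommands")

greet_parser = commands.add_parser("greet", help="Print a greeting.")
greet_parser.add_argument("name")
# Attach the handler the same way the package does
greet_parser.set_defaults(func=greet)

args = vars(parser.parse_args(["greet", "world"]))
if "func" in args:
    # The remaining entries match the handler's keyword arguments
    args.pop("func")(**args)
else:
    parser.print_help()

This is also why each handler, such as `run_dataset_extraction`, must accept keyword arguments that exactly match the option names its parser defines, with dashes mapped to underscores.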
layout_reader/datasets/__init__.py
ADDED

@@ -0,0 +1,99 @@

"""Manage datasets."""

import argparse
from pathlib import Path

from layout_reader.datasets.analyze import run_dataset_analysis
from layout_reader.datasets.extract import Mode, run_dataset_extraction


def add_dataset_command(commands) -> None:
    export = commands.add_parser(
        "dataset",
        description=__doc__,
        help=__doc__,
    )
    subcommands = export.add_subparsers(metavar="subcommands")
    add_extract_parser(subcommands)
    add_analyze_parser(subcommands)


def valid_rate(rate: float) -> float:
    """Convert to a float between 0 and 1."""
    try:
        rate = float(rate)
    except ValueError as e:
        raise argparse.ArgumentTypeError(
            f"The shuffle rate ({rate}) should be a float."
        ) from e

    if rate > 1 or rate < 0:
        raise argparse.ArgumentTypeError(
            f"The shuffle rate ({rate}) should be between 0 and 1."
        )
    return rate


def add_extract_parser(subcommands) -> None:
    parser = subcommands.add_parser(
        "extract",
        description="Convert a HF dataset into LayoutReader format.",
        help="Convert a HF dataset into LayoutReader format.",
    )

    parser.add_argument(
        "dataset",
        type=str,
        help="Name of the HuggingFace or YOLO dataset",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        help="Output directory.",
        required=True,
    )
    parser.add_argument(
        "--mode",
        type=Mode,
        help=f"Dataset type. Must be in {[e.value for e in Mode]}",
        required=True,
    )
    # Optional
    parser.add_argument(
        "--shuffle-rate",
        type=valid_rate,
        help="Ratio of the data that will be shuffled (expected between 0 and 1).",
        default=0.5,
    )
    parser.add_argument(
        "--extract-classes",
        action="store_true",
        help="Whether to extract classes.",
    )
    parser.add_argument(
        "--extract-separators",
        action="store_true",
        help="Whether to extract separators.",
    )
    parser.set_defaults(func=run_dataset_extraction)


def add_analyze_parser(subcommands) -> None:
    parser = subcommands.add_parser(
        "analyze",
        description="Analyze a LayoutReader dataset.",
        help="Analyze a LayoutReader dataset.",
    )

    parser.add_argument(
        "dataset_path",
        type=Path,
        help="Path of the LayoutReader dataset.",
    )
    parser.add_argument(
        "--report-path",
        type=Path,
        help="Path to save the Markdown report.",
        default="dataset_report.md",
    )
    parser.set_defaults(func=run_dataset_analysis)
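Because `valid_rate` is used as an argparse `type=` callable, any `ArgumentTypeError` it raises is reported by argparse as a usage error rather than a traceback. A short sketch of how that surfaces, assuming the wheel is installed so that `valid_rate` is importable:

import argparse

from layout_reader.datasets import valid_rate

parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("--shuffle-rate", type=valid_rate, default=0.5)

print(parser.parse_args(["--shuffle-rate", "0.3"]).shuffle_rate)  # 0.3
# parser.parse_args(["--shuffle-rate", "1.5"]) exits with a usage error
# along the lines of:
#   demo: error: argument --shuffle-rate: The shuffle rate (1.5) should be
#   between 0 and 1.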
layout_reader/datasets/analyze.py
ADDED

@@ -0,0 +1,161 @@

"""
Analyze a local LayoutReader dataset.
"""

import logging
from collections import Counter, defaultdict
from functools import partial
from pathlib import Path

import numpy as np
from mdutils.mdutils import MdUtils
from prettytable import PrettyTable, TableStyle

from layout_reader.datasets.utils import load_gzip_jsonl

logger = logging.getLogger(__name__)

METRIC_COLUMN = "Metric"


def create_table(
    data: dict,
    count: bool = False,
    total: bool = True,
) -> PrettyTable:
    """
    Each key will be made into a column.
    We compute min, max, mean, median, total by default.
    Total can be disabled. Count (length) computation can be enabled.
    """

    statistics = PrettyTable(field_names=[METRIC_COLUMN, *data.keys()])
    statistics.align.update({METRIC_COLUMN: "l"})
    statistics.set_style(TableStyle.MARKDOWN)

    operations = []

    if count:
        operations.append(("Count", len, None))

    operations.extend(
        [
            ("Min", np.min, None),
            ("Max", np.max, None),
            ("Mean", np.mean, 2),
            ("Median", np.median, 2),
        ]
    )
    if total:
        operations.append(("Total", np.sum, None))

    statistics.add_rows(
        [
            [
                col_name,
                *list(
                    map(
                        # Round values if needed
                        partial(round, ndigits=digits),
                        map(operator, data.values()),
                    )
                ),
            ]
            for col_name, operator, digits in operations
        ]
    )

    return statistics


class Statistics:
    HEADERS = {
        "Coordinates": "Coordinates statistics",
        "Classes": "Classes statistics",
        "Separators": "Separators statistics",
    }

    def __init__(self, output: Path) -> None:
        self.document = MdUtils(file_name=str(output), title="Statistics")

    def _write_section(self, table: PrettyTable, title: str, level: int = 2) -> None:
        """
        Write the new section in the file.

        <title with appropriate level>

        <table>

        """
        self.document.new_header(level=level, title=title, add_table_of_contents="n")
        self.document.write("\n")

        logger.info(f"{title}\n\n{table}\n")

        self.document.write(table.get_string())
        self.document.write("\n")

    def create_classes_statistics(self, labels: list[list[int]], title: str) -> None:
        """
        Compute statistics on class labels and write them to file.
        """
        class_counter = Counter()

        for classes in labels:
            class_counter.update(classes)

        statistics = PrettyTable(
            field_names=[METRIC_COLUMN] + [f"Class ID {k}" for k in class_counter]
        )
        statistics.set_style(TableStyle.MARKDOWN)

        statistics.add_row(["Count"] + list(class_counter.values()))

        self._write_section(
            table=statistics,
            title=title,
        )

    def create_boxes_statistics(self, labels: list[list[float]], title: str) -> None:
        """
        Compute statistics on bounding boxes and write them to file.
        """
        data = defaultdict(list)

        for page in labels:
            for box in page:
                x1, y1, x2, y2 = box
                data["Box width"].append(x2 - x1)
                data["Box height"].append(y2 - y1)
                data["Box surface (%)"].append((x2 - x1) * (y2 - y1) / 10000)

        self._write_section(
            table=create_table(data=data, total=False),
            title=title,
        )

    def run(self, filenames: list[Path]) -> None:
        """Iterate over each split and create report"""
        for filename in filenames:
            self.document.new_header(level=1, title=str(filename.stem).capitalize())
            labels = load_gzip_jsonl(filename)
            self.create_classes_statistics(
                labels=[page["target_classes"] for page in labels],
                title=Statistics.HEADERS["Classes"],
            )
            self.create_boxes_statistics(
                labels=[page["target_boxes"] for page in labels],
                title=Statistics.HEADERS["Coordinates"],
            )
            self.create_boxes_statistics(
                labels=[page["separators"] for page in labels],
                title=Statistics.HEADERS["Separators"],
            )
        self.document.create_md_file()


def run_dataset_analysis(dataset_path: Path, report_path: Path) -> None:
    """
    Compute statistics on a dataset in LayoutReader format.
    """
    Statistics(output=report_path).run(filenames=sorted(dataset_path.glob("*.jsonl.gz")))
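The splits consumed here are the `.jsonl.gz` files written by the extractor below. `load_gzip_jsonl` lives in `layout_reader/datasets/utils.py`, which is not shown in this section; a minimal stand-in, assuming gzip-compressed JSON Lines with one page record per line carrying the keys the extractor emits (`sample_id`, `source_boxes`, `target_boxes`, `target_classes`, `target_index`, `separators`):

import gzip
import json
from pathlib import Path

def read_jsonl_gz(filename: Path) -> list[dict]:
    # Stand-in for load_gzip_jsonl: one JSON object per line, gzip-compressed
    with gzip.open(filename, "rt", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

pages = read_jsonl_gz(Path("dataset/train.jsonl.gz"))  # placeholder path
print(f"{len(pages)} pages; first page has {len(pages[0]['target_boxes'])} zones")

The analyzer itself is presumably run through the console script, e.g. `teklia-layout-reader dataset analyze <dataset_path> --report-path report.md`.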
layout_reader/datasets/extract.py
ADDED

@@ -0,0 +1,289 @@

"""
Convert a HuggingFace dataset to LayoutReader format.
Example: https://huggingface.co/datasets/Teklia/Newspapers-finlam.
"""

import logging
import random
import tempfile
from enum import Enum
from pathlib import Path

import numpy as np
import tqdm
from datasets import Dataset, load_dataset, load_from_disk

from layout_reader.datasets.lsd import LineDetector
from layout_reader.datasets.utils import (
    BBX_FACTOR,
    CLS_SHIFT,
    UNK_TOKEN_ID,
    check_is_valid_bbx,
    check_too_many_zones,
    convert_to_bbx,
    load_yolo_line,
    make_bbx_valid,
    save_gzip_jsonl,
)

random.seed(42)
np.random.seed(42)

logger = logging.getLogger(__name__)


class Mode(Enum):
    HF = "HF"
    YOLO = "YOLO"


SPLIT_MAPPING = {
    "val": "dev",
}


def sort_and_filter_zones(
    orders: list,
    boxes: list,
    classes: list,
) -> tuple[list, list]:
    """
    Order all boxes and classes, dropping zones whose bounding box is invalid.
    """
    filtered_boxes = []
    filtered_classes = []
    filtered_orders = []

    # Iterate on zones
    for order, box, classif in zip(
        orders,
        boxes,
        classes,
        strict=True,
    ):
        # Normalize box
        box = make_bbx_valid(box)
        if not check_is_valid_bbx(box):
            logger.warning(
                f"Ignoring box {box} (order {order}) as its bounding box is invalid"
            )
            continue
        filtered_boxes.append(box)
        filtered_classes.append(classif)
        filtered_orders.append(order)

    # Sort boxes and classes based on order
    target_boxes = np.array(filtered_boxes)[np.argsort(filtered_orders)].tolist()
    target_classes = np.array(filtered_classes)[np.argsort(filtered_orders)].tolist()
    return target_boxes, target_classes


def prepare_training_data(
    target_boxes, target_classes, shuffle_rate: float = 0.5
) -> tuple[list[list[int]], list[int], list[int]]:
    """
    Create dataset of zones from the target order.
    """
    # Shuffle zones or sort zones by XY
    zones = [
        (order, features[0], features[1])
        for order, features in enumerate(zip(target_boxes, target_classes, strict=True))
    ]
    if random.random() < shuffle_rate:
        random.shuffle(zones)
    else:
        zones = sorted(zones, key=lambda z: (z[1][0], z[1][1]))

    # Build unsorted boxes & classes from target zones
    source_orders, source_boxes, source_classes = map(list, zip(*zones, strict=True))

    # Get sorted index and start at 1
    target_index = (np.argsort(source_orders) + 1).tolist()
    return source_boxes, source_classes, target_index


def extract_split_hf(
    subset: Dataset,
    shuffle_rate: float,
    extract_classes: bool = True,
    extract_separators: bool = True,
) -> list[dict]:
    """
    Extract data from a Dataset split.
    """
    samples = []
    # Iterate on pages
    for page in tqdm.tqdm(subset):
        # Convert polygons to boxes
        boxes = [
            convert_to_bbx(polygon, BBX_FACTOR) for polygon in page["zone_polygons"]
        ]
        classes = [
            classif + CLS_SHIFT if extract_classes else UNK_TOKEN_ID
            for classif in page["zone_classes"]
        ]
        orders = page["zone_orders"]

        # Filter invalid zones and sort them by order
        target_boxes, target_classes = sort_and_filter_zones(orders, boxes, classes)

        # Get separators
        separators = []
        if extract_separators:
            with tempfile.NamedTemporaryFile(suffix=".jpg") as tmp_file:
                image_path = tmp_file.name
                page["page_image"].save(image_path)
                detector = LineDetector()
                separators = detector.process(image_path)

        # Check that the number of zones is < MAX_LEN
        if check_too_many_zones(boxes=target_boxes, separators=separators):
            logger.warning("Too many zones on the current page - skipping.")
            continue

        # Create input data for Layout Reader - sort or shuffle
        source_boxes, source_classes, target_index = prepare_training_data(
            target_boxes, target_classes, shuffle_rate=shuffle_rate
        )
        samples.append(
            {
                "sample_id": page["page_arkindex_id"],
                "source_boxes": source_boxes,
                "source_classes": source_classes,
                "target_boxes": target_boxes,
                "target_classes": target_classes,
                "target_index": target_index,
                "separators": separators,
            }
        )
    return samples


def run_dataset_extraction_hf(
    dataset: Dataset,
    shuffle_rate: float,
    extract_classes: bool = True,
    extract_separators: bool = True,
) -> dict:
    """
    Extract a HuggingFace dataset and convert it to LayoutReader format.
    """
    data = {}
    for split in dataset:
        # Extract data
        logger.info(f"Extracting split {split}.")
        data[split] = extract_split_hf(
            dataset[split],
            shuffle_rate,
            extract_classes,
            extract_separators,
        )
    return data


def convert_yolo_labels(lines: list, extract_classes: bool) -> tuple[list, list]:
    """
    Convert YOLO labels to LayoutReader features.
    Assumes the list of zones is ordered.
    """
    target_boxes = []
    target_classes = []
    for line in lines:
        classif, box = load_yolo_line(line, BBX_FACTOR)
        box = make_bbx_valid(box)
        if not check_is_valid_bbx(box):
            logger.warning(f"Ignoring box {box} as bounding box is invalid")
            continue
        target_boxes.append(box)
        target_classes.append(classif + CLS_SHIFT if extract_classes else UNK_TOKEN_ID)
    return target_boxes, target_classes


def run_dataset_extraction_yolo(
    dataset: str,
    shuffle_rate: float,
    extract_classes: bool = True,
    extract_separators: bool = True,
) -> dict:
    data = {}
    for subset in Path(dataset).iterdir():
        samples = []
        labels_path = sorted((Path(dataset) / subset.name / "labels").rglob("*.txt"))
        images_base_path = Path(dataset) / subset.name / "images"

        for label_path in tqdm.tqdm(labels_path):
            image_path = images_base_path / label_path.with_suffix(".jpg").name
            if not image_path.exists():
                continue

            # Remove duplicate lines (YOLO finds many of them)
            yolo_lines = list(dict.fromkeys(label_path.read_text().splitlines()))
            target_boxes, target_classes = convert_yolo_labels(
                yolo_lines, extract_classes
            )

            # Extract separators
            separators = []
            if extract_separators:
                detector = LineDetector()
                separators = detector.process(image_path)

            # Check that the number of zones is < MAX_LEN
            if check_too_many_zones(boxes=target_boxes, separators=separators):
                logger.warning("Too many zones on the current page - skipping.")
                continue

            # Create input data for Layout Reader - sort or shuffle
            source_boxes, source_classes, target_index = prepare_training_data(
                target_boxes, target_classes, shuffle_rate=shuffle_rate
            )
            samples.append(
                {
                    "sample_id": label_path.stem,
                    "source_boxes": source_boxes,
                    "source_classes": source_classes,
                    "target_boxes": target_boxes,
                    "target_classes": target_classes,
                    "target_index": target_index,
                    "separators": separators,
                }
            )
        data[subset.name] = samples
    return data


def run_dataset_extraction(
    dataset: str,
    mode: Mode,
    shuffle_rate: float,
    output_dir: Path,
    extract_classes: bool = True,
    extract_separators: bool = True,
) -> None:
    output_dir.mkdir(exist_ok=True)

    if mode == Mode.HF:
        if Path(dataset).exists():
            dataset = load_from_disk(dataset)
        else:
            dataset = load_dataset(dataset)

        data = run_dataset_extraction_hf(
            dataset,
            shuffle_rate,
            extract_classes,
            extract_separators,
        )
    else:
        data = run_dataset_extraction_yolo(
            dataset,
            shuffle_rate,
            extract_classes,
            extract_separators,
        )

    # Write to JSONL in gzip format
    for subset in data:
        out_name = output_dir / f"{SPLIT_MAPPING.get(subset, subset)}.jsonl.gz"
        logger.info(f"Saving {subset} data to file: {out_name}.")
        save_gzip_jsonl(filename=out_name, content=data[subset])
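To make the `target_index` encoding concrete: `prepare_training_data` stores, for each reading-order rank, the 1-based position of that zone in the (shuffled or XY-sorted) source sequence, so reading order can be recovered from a predicted index. A toy example, using strings in place of real `[x1, y1, x2, y2]` boxes:

import numpy as np

# Reading order is A, B, C; a hypothetical shuffle yields C, A, B.
source_boxes = ["C", "A", "B"]
source_orders = [2, 0, 1]  # original reading-order rank of each source zone

# Same computation as prepare_training_data: 1-based rank lookup
target_index = (np.argsort(source_orders) + 1).tolist()
print(target_index)  # [2, 3, 1]

# A model that predicts target_index lets us restore reading order:
print([source_boxes[i - 1] for i in target_index])  # ['A', 'B', 'C']

End to end, extraction is presumably driven through the console script, e.g. `teklia-layout-reader dataset extract Teklia/Newspapers-finlam --mode HF --output-dir ./data`, with `val` splits renamed to `dev` on disk via `SPLIT_MAPPING`.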