teklia-layout-reader 0.2.1__py3-none-any.whl

@@ -0,0 +1,9 @@
+ import logging
+ import random
+
+ logging.basicConfig(
+     format="%(asctime)s %(levelname)s/%(name)s: %(message)s",
+     level=logging.INFO,
+ )
+
+ random.seed(42)
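The hunk header above does not name the file; if it is the package's __init__ module, then simply importing layout_reader applies the logging format and pins the global random seed, which is what makes the shuffle decisions taken later during dataset extraction reproducible across runs. A minimal sketch under that assumption:

import random

import layout_reader  # noqa: F401  (assumed: importing the package configures logging and calls random.seed(42))

items = list(range(5))
random.shuffle(items)
print(items)  # same permutation on every fresh interpreter run, since the seed is fixed at import time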
layout_reader/cli.py ADDED
@@ -0,0 +1,27 @@
+ import argparse
+
+ from layout_reader.datasets import add_dataset_command
+ from layout_reader.inference import add_predict_command
+ from layout_reader.train.sft import add_training_command
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         prog="teklia-layout-reader",
+         description="Scripts for Layout Reader",
+     )
+
+     commands = parser.add_subparsers(metavar="subcommands")
+     add_dataset_command(commands)
+     add_training_command(commands)
+     add_predict_command(commands)
+
+     args = vars(parser.parse_args())
+     if "func" in args:
+         args.pop("func")(**args)
+     else:
+         parser.print_help()
+
+
+ if __name__ == "__main__":
+     main()
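Each subcommand module registers a callable through parser.set_defaults(func=...); main() then turns the parsed namespace into a dict, pops func out of it, and calls it with the remaining entries as keyword arguments. A minimal, self-contained sketch of the same dispatch pattern, using a hypothetical hello subcommand in place of the real ones:

import argparse


def run_hello(name: str) -> None:
    print(f"Hello, {name}")


parser = argparse.ArgumentParser(prog="demo")
commands = parser.add_subparsers(metavar="subcommands")
hello = commands.add_parser("hello")
hello.add_argument("--name", default="world")
hello.set_defaults(func=run_hello)

args = vars(parser.parse_args(["hello", "--name", "LayoutReader"]))
if "func" in args:
    args.pop("func")(**args)  # equivalent to run_hello(name="LayoutReader")
else:
    parser.print_help()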
@@ -0,0 +1,99 @@
+ """Manage datasets."""
+
+ import argparse
+ from pathlib import Path
+
+ from layout_reader.datasets.analyze import run_dataset_analysis
+ from layout_reader.datasets.extract import Mode, run_dataset_extraction
+
+
+ def add_dataset_command(commands) -> None:
+     export = commands.add_parser(
+         "dataset",
+         description=__doc__,
+         help=__doc__,
+     )
+     subcommands = export.add_subparsers(metavar="subcommands")
+     add_extract_parser(subcommands)
+     add_analyze_parser(subcommands)
+
+
+ def valid_rate(rate: float) -> float:
+     """Convert to a float between 0 and 1."""
+     try:
+         rate = float(rate)
+     except ValueError as e:
+         raise argparse.ArgumentTypeError(
+             f"The shuffle rate ({rate}) should be a float."
+         ) from e
+
+     if rate > 1 or rate < 0:
+         raise argparse.ArgumentTypeError(
+             f"The shuffle rate ({rate}) should be between 0 and 1."
+         )
+     return rate
+
+
+ def add_extract_parser(subcommands) -> None:
+     parser = subcommands.add_parser(
+         "extract",
+         description="Convert a HF dataset into LayoutReader format.",
+         help="Convert a HF dataset into LayoutReader format.",
+     )
+
+     parser.add_argument(
+         "dataset",
+         type=str,
+         help="Name of the HuggingFace or YOLO dataset",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=Path,
+         help="Output directory.",
+         required=True,
+     )
+     parser.add_argument(
+         "--mode",
+         type=Mode,
+         help=f"Dataset type. Must be in {[e.value for e in Mode]}",
+         required=True,
+     )
+     # Optional
+     parser.add_argument(
+         "--shuffle-rate",
+         type=valid_rate,
+         help="Ratio of the data that will be shuffled (expected between 0 and 1).",
+         default=0.5,
+     )
+     parser.add_argument(
+         "--extract-classes",
+         action="store_true",
+         help="Whether to extract classes.",
+     )
+     parser.add_argument(
+         "--extract-separators",
+         action="store_true",
+         help="Whether to extract separators.",
+     )
+     parser.set_defaults(func=run_dataset_extraction)
+
+
+ def add_analyze_parser(subcommands) -> None:
+     parser = subcommands.add_parser(
+         "analyze",
+         description="Analyze a LayoutReader dataset.",
+         help="Analyze a LayoutReader dataset.",
+     )
+
+     parser.add_argument(
+         "dataset_path",
+         type=Path,
+         help="Path of the LayoutReader dataset.",
+     )
+     parser.add_argument(
+         "--report-path",
+         type=Path,
+         help="Path to save the Markdown report.",
+         default="dataset_report.md",
+     )
+     parser.set_defaults(func=run_dataset_analysis)
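valid_rate is wired in as the argparse type for --shuffle-rate, so it must either return the converted float or raise argparse.ArgumentTypeError. A quick sketch of the expected behaviour, assuming this hunk is layout_reader/datasets/__init__.py (the import in cli.py suggests it is) and the package is installed:

import argparse

from layout_reader.datasets import valid_rate

assert valid_rate("0.25") == 0.25  # accepted and converted to float

for bad_value in ("1.5", "abc"):
    try:
        valid_rate(bad_value)
    except argparse.ArgumentTypeError as error:
        print(error)  # "should be between 0 and 1" / "should be a float"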
@@ -0,0 +1,161 @@
+ """
+ Analyze a local LayoutReader dataset.
+ """
+
+ import logging
+ from collections import Counter, defaultdict
+ from functools import partial
+ from pathlib import Path
+
+ import numpy as np
+ from mdutils.mdutils import MdUtils
+ from prettytable import PrettyTable, TableStyle
+
+ from layout_reader.datasets.utils import load_gzip_jsonl
+
+ logger = logging.getLogger(__name__)
+
+ METRIC_COLUMN = "Metric"
+
+
+ def create_table(
+     data: dict,
+     count: bool = False,
+     total: bool = True,
+ ) -> PrettyTable:
+     """
+     Each key will be made into a column.
+     We compute min, max, mean, median, total by default.
+     Total can be disabled. Count (length) computation can be enabled.
+     """
+
+     statistics = PrettyTable(field_names=[METRIC_COLUMN, *data.keys()])
+     statistics.align.update({METRIC_COLUMN: "l"})
+     statistics.set_style(TableStyle.MARKDOWN)
+
+     operations = []
+
+     if count:
+         operations.append(("Count", len, None))
+
+     operations.extend(
+         [
+             ("Min", np.min, None),
+             ("Max", np.max, None),
+             ("Mean", np.mean, 2),
+             ("Median", np.median, 2),
+         ]
+     )
+     if total:
+         operations.append(("Total", np.sum, None))
+
+     statistics.add_rows(
+         [
+             [
+                 col_name,
+                 *list(
+                     map(
+                         # Round values if needed
+                         partial(round, ndigits=digits),
+                         map(operator, data.values()),
+                     )
+                 ),
+             ]
+             for col_name, operator, digits in operations
+         ]
+     )
+
+     return statistics
+
+
+ class Statistics:
+     HEADERS = {
+         "Coordinates": "Coordinates statistics",
+         "Classes": "Classes statistics",
+         "Separators": "Separators statistics",
+     }
+
+     def __init__(self, output: Path) -> None:
+         self.document = MdUtils(file_name=str(output), title="Statistics")
+
+     def _write_section(self, table: PrettyTable, title: str, level: int = 2) -> None:
+         """
+         Write the new section in the file.
+
+         <title with appropriate level>
+
+         <table>
+
+         """
+         self.document.new_header(level=level, title=title, add_table_of_contents="n")
+         self.document.write("\n")
+
+         logger.info(f"{title}\n\n{table}\n")
+
+         self.document.write(table.get_string())
+         self.document.write("\n")
+
+     def create_classes_statistics(self, labels: list[str], title: str) -> None:
+         """
+         Compute statistics on class labels and write them to file.
+         """
+         class_counter = Counter()
+
+         for classes in labels:
+             class_counter.update(classes)
+
+         statistics = PrettyTable(
+             field_names=[METRIC_COLUMN] + [f"Class ID {k}" for k in class_counter]
+         )
+         statistics.set_style(TableStyle.MARKDOWN)
+
+         statistics.add_row(["Count"] + list(class_counter.values()))
+
+         self._write_section(
+             table=statistics,
+             title=title,
+         )
+
+     def create_boxes_statistics(self, labels: list[list[float]], title: str) -> None:
+         """
+         Compute statistics on bounding boxes and write them to file.
+         """
+         data = defaultdict(list)
+
+         for page in labels:
+             for box in page:
+                 x1, y1, x2, y2 = box
+                 data["Box width"].append(x2 - x1)
+                 data["Box height"].append(y2 - y1)
+                 data["Box surface (%)"].append((x2 - x1) * (y2 - y1) / 10000)
+
+         self._write_section(
+             table=create_table(data=data, total=False),
+             title=title,
+         )
+
+     def run(self, filenames: list[Path]) -> None:
+         """Iterate over each split and create report"""
+         for filename in filenames:
+             self.document.new_header(level=1, title=str(filename.stem).capitalize())
+             labels = load_gzip_jsonl(filename)
+             self.create_classes_statistics(
+                 labels=[page["target_classes"] for page in labels],
+                 title=Statistics.HEADERS["Classes"],
+             )
+             self.create_boxes_statistics(
+                 labels=[page["target_boxes"] for page in labels],
+                 title=Statistics.HEADERS["Coordinates"],
+             )
+             self.create_boxes_statistics(
+                 labels=[page["separators"] for page in labels],
+                 title=Statistics.HEADERS["Separators"],
+             )
+         self.document.create_md_file()
+
+
+ def run_dataset_analysis(dataset_path: Path, report_path: Path) -> None:
+     """
+     Compute statistics on a dataset in LayoutReader format.
+     """
+     Statistics(output=report_path).run(filenames=dataset_path.glob("*.jsonl.gz"))
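create_table builds one column per key of the input dict and one row per statistic (Count and Total are optional). A small sketch of how it is meant to be called, assuming the hunk above is layout_reader/datasets/analyze.py (the import in layout_reader/datasets references that module path) and its dependencies are installed:

from layout_reader.datasets.analyze import create_table

table = create_table(
    data={"Box width": [10, 20, 30], "Box height": [5, 5, 20]},
    count=True,
    total=False,
)
print(table)
# Markdown-style table: columns "Metric", "Box width", "Box height";
# rows Count, Min, Max, Mean, Median (means and medians rounded to 2 digits).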
@@ -0,0 +1,289 @@
+ """
+ Convert a HuggingFace dataset to LayoutReader format.
+ Example: https://huggingface.co/datasets/Teklia/Newspapers-finlam.
+ """
+
+ import logging
+ import random
+ import tempfile
+ from enum import Enum
+ from pathlib import Path
+
+ import numpy as np
+ import tqdm
+ from datasets import Dataset, load_dataset, load_from_disk
+
+ from layout_reader.datasets.lsd import LineDetector
+ from layout_reader.datasets.utils import (
+     BBX_FACTOR,
+     CLS_SHIFT,
+     UNK_TOKEN_ID,
+     check_is_valid_bbx,
+     check_too_many_zones,
+     convert_to_bbx,
+     load_yolo_line,
+     make_bbx_valid,
+     save_gzip_jsonl,
+ )
+
+ random.seed(42)
+ np.random.seed(42)
+
+ logger = logging.getLogger(__name__)
+
+
+ class Mode(Enum):
+     HF = "HF"
+     YOLO = "YOLO"
+
+
+ SPLIT_MAPPING = {
+     "val": "dev",
+ }
+
+
+ def sort_and_filter_zones(
+     orders: list,
+     boxes: list,
+     classes: list,
+ ) -> tuple:
+     """
+     Filter out invalid boxes and sort the remaining boxes and classes by reading order.
+     """
+     filtered_boxes = []
+     filtered_classes = []
+     filtered_orders = []
+
+     # Iterate on zones
+     for order, box, classif in zip(
+         orders,
+         boxes,
+         classes,
+         strict=True,
+     ):
+         # Normalize box
+         box = make_bbx_valid(box)
+         if not check_is_valid_bbx(box):
+             logger.warning(
+                 f"Ignoring box {box} ({order}) for element as bounding box is invalid"
+             )
+             continue
+         filtered_boxes.append(box)
+         filtered_classes.append(classif)
+         filtered_orders.append(order)
+
+     # Sort boxes and classes based on the order of the zones that were kept
+     target_boxes = np.array(filtered_boxes)[np.argsort(filtered_orders)].tolist()
+     target_classes = np.array(filtered_classes)[np.argsort(filtered_orders)].tolist()
+     return target_boxes, target_classes
+
+
+ def prepare_training_data(
+     target_boxes, target_classes, shuffle_rate: float = 0.5
+ ) -> tuple[list[list[int]], list[int], list[int]]:
+     """
+     Create dataset of zones from the target order.
+     """
+     # Shuffle zones or sort zones by XY
+     zones = [
+         (order, features[0], features[1])
+         for order, features in enumerate(zip(target_boxes, target_classes, strict=True))
+     ]
+     if random.random() < shuffle_rate:
+         random.shuffle(zones)
+     else:
+         zones = sorted(zones, key=lambda z: (z[1][0], z[1][1]))
+
+     # Build unsorted boxes & classes from target zones
+     source_orders, source_boxes, source_classes = map(list, zip(*zones, strict=True))
+
+     # Get sorted index and start at 1
+     target_index = (np.argsort(source_orders) + 1).tolist()
+     return source_boxes, source_classes, target_index
+
+
+ def extract_split_hf(
+     subset: Dataset,
+     shuffle_rate: float,
+     extract_classes: bool = True,
+     extract_separators: bool = True,
+ ) -> list[dict]:
+     """
+     Extract data from a Dataset split.
+     """
+     samples = []
+     # Iterate on pages
+     for page in tqdm.tqdm(subset):
+         # Convert polygons to boxes
+         boxes = [
+             convert_to_bbx(polygon, BBX_FACTOR) for polygon in page["zone_polygons"]
+         ]
+         classes = [
+             classif + CLS_SHIFT if extract_classes else UNK_TOKEN_ID
+             for classif in page["zone_classes"]
+         ]
+         orders = page["zone_orders"]
+
+         # Filter invalid zones and sort them by order
+         target_boxes, target_classes = sort_and_filter_zones(orders, boxes, classes)
+
+         # Get separators
+         separators = []
+         if extract_separators:
+             with tempfile.NamedTemporaryFile(suffix=".jpg") as tmp_file:
+                 image_path = tmp_file.name
+                 page["page_image"].save(image_path)
+                 detector = LineDetector()
+                 separators = detector.process(image_path)
+
+         # Check that the number of zones is < MAX_LEN
+         if check_too_many_zones(boxes=target_boxes, separators=separators):
+             logger.warning("Too many zones on the current page - skipping.")
+             continue
+
+         # Create input data for Layout Reader - sort or shuffle
+         source_boxes, source_classes, target_index = prepare_training_data(
+             target_boxes, target_classes, shuffle_rate=shuffle_rate
+         )
+         samples.append(
+             {
+                 "sample_id": page["page_arkindex_id"],
+                 "source_boxes": source_boxes,
+                 "source_classes": source_classes,
+                 "target_boxes": target_boxes,
+                 "target_classes": target_classes,
+                 "target_index": target_index,
+                 "separators": separators,
+             }
+         )
+     return samples
+
+
+ def run_dataset_extraction_hf(
+     dataset: Dataset,
+     shuffle_rate: float,
+     extract_classes: bool = True,
+     extract_separators: bool = True,
+ ) -> dict:
+     """
+     Extract each split of a HuggingFace dataset and format it in LayoutReader format.
+     """
+     data = {}
+     for split in dataset:
+         # Extract data
+         logger.info(f"Extracting split {split}.")
+         data[split] = extract_split_hf(
+             dataset[split],
+             shuffle_rate,
+             extract_classes,
+             extract_separators,
+         )
+     return data
+
+
+ def convert_yolo_labels(lines: list, extract_classes: bool):
+     """
+     Convert YOLO labels to LayoutReader features.
+     Assumes the list of zones is ordered.
+     """
+     target_boxes = []
+     target_classes = []
+     for line in lines:
+         classif, box = load_yolo_line(line, BBX_FACTOR)
+         box = make_bbx_valid(box)
+         if not check_is_valid_bbx(box):
+             logger.warning(f"Ignoring box {box} as bounding box is invalid")
+             continue
+         target_boxes.append(box)
+         target_classes.append(classif + CLS_SHIFT if extract_classes else UNK_TOKEN_ID)
+     return target_boxes, target_classes
+
+
+ def run_dataset_extraction_yolo(
+     dataset: str,
+     shuffle_rate: float,
+     extract_classes: bool = True,
+     extract_separators: bool = True,
+ ) -> dict:
+     """
+     Extract a YOLO-format dataset and format it in LayoutReader format.
+     """
+     data = {}
+     for subset in Path(dataset).iterdir():
+         samples = []
+         labels_path = sorted((Path(dataset) / subset.name / "labels").rglob("*.txt"))
+         images_base_path = Path(dataset) / subset.name / "images"
+
+         for label_path in tqdm.tqdm(labels_path):
+             image_path = images_base_path / label_path.with_suffix(".jpg").name
+             if not image_path.exists():
+                 continue
+
+             # Remove duplicate lines (YOLO finds many of them)
+             yolo_lines = label_path.read_text().splitlines()
+             target_boxes, target_classes = convert_yolo_labels(
+                 yolo_lines, extract_classes
+             )
+
+             # Extract separators
+             separators = []
+             if extract_separators:
+                 detector = LineDetector()
+                 separators = detector.process(image_path)
+
+             # Check that the number of zones is < MAX_LEN
+             if check_too_many_zones(boxes=target_boxes, separators=separators):
+                 logger.warning("Too many zones on the current page - skipping.")
+                 continue
+
+             # Create input data for Layout Reader - sort or shuffle
+             source_boxes, source_classes, target_index = prepare_training_data(
+                 target_boxes, target_classes, shuffle_rate=shuffle_rate
+             )
+             samples.append(
+                 {
+                     "sample_id": label_path.stem,
+                     "source_boxes": source_boxes,
+                     "source_classes": source_classes,
+                     "target_boxes": target_boxes,
+                     "target_classes": target_classes,
+                     "target_index": target_index,
+                     "separators": separators,
+                 }
+             )
+         data[subset.name] = samples
+     return data
+
+
+ def run_dataset_extraction(
+     dataset: str,
+     mode: Mode,
+     shuffle_rate: float,
+     output_dir: Path,
+     extract_classes: bool = True,
+     extract_separators: bool = True,
+ ) -> None:
+     output_dir.mkdir(exist_ok=True)
+
+     if mode == Mode.HF:
+         if Path(dataset).exists():
+             dataset = load_from_disk(dataset)
+         else:
+             dataset = load_dataset(dataset)
+
+         data = run_dataset_extraction_hf(
+             dataset,
+             shuffle_rate,
+             extract_classes,
+             extract_separators,
+         )
+     else:
+         data = run_dataset_extraction_yolo(
+             dataset,
+             shuffle_rate,
+             extract_classes,
+             extract_separators,
+         )
+
+     # Write to JSONL in gzip format
+     for subset in data:
+         out_name = output_dir / f"{SPLIT_MAPPING.get(subset, subset)}.jsonl.gz"
+         logger.info(f"Saving {subset} data to file: {out_name}.")
+         save_gzip_jsonl(filename=out_name, content=data[subset])
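prepare_training_data either shuffles the zones (with probability shuffle_rate) or re-sorts them by the top-left corner of their boxes, then records, for each position in the reading order, the 1-based index of the corresponding zone in the resulting source sequence (np.argsort(source_orders) + 1). A small worked sketch with shuffle_rate=0.0 so the result is deterministic; the module path follows the import in layout_reader/datasets, and the package with its dependencies is assumed to be installed:

from layout_reader.datasets.extract import prepare_training_data

# Three zones given in reading order; each box is (x1, y1, x2, y2).
target_boxes = [[50, 10, 90, 20], [10, 30, 40, 40], [10, 5, 40, 15]]
target_classes = [7, 8, 9]

# With shuffle_rate=0.0 the zones are re-sorted by (x1, y1) instead of being shuffled.
source_boxes, source_classes, target_index = prepare_training_data(
    target_boxes, target_classes, shuffle_rate=0.0
)
print(source_boxes)    # [[10, 5, 40, 15], [10, 30, 40, 40], [50, 10, 90, 20]]
print(source_classes)  # [9, 8, 7]
print(target_index)    # [3, 2, 1] -> read the 3rd source zone first, then the 2nd, then the 1st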