teklia-layout-reader 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,135 @@
+ from pathlib import Path
+
+ import cv2
+ import numpy as np
+ from matplotlib import pyplot as plt
+
+ from layout_reader.datasets.utils import check_is_valid_bbx, make_bbx_valid, resize_bbx
+
+
+ class LineDetector:
+     def __init__(
+         self,
+         process_size: tuple[int, int] = (2000, 1000),
+         target_size: tuple[int, int] = (1000, 1000),
+         filter_ratio: tuple[float, float] = (0.05, 0.05),
+     ):
+         """
+         Initialize the line detector with processing and filtering parameters.
+
+         Args:
+             process_size: (width, height) images are resized to before detection
+             target_size: (width, height) detected lines are rescaled to
+             filter_ratio: (width_ratio, height_ratio) minimum extent of a line,
+                 relative to the target size, for it to be kept
+         """
+         self.process_width, self.process_height = process_size
+         self.target_width, self.target_height = target_size
+         self.filter_width_ratio, self.filter_height_ratio = filter_ratio
+
+         # Create LSD detector
+         self.lsd = cv2.createLineSegmentDetector(cv2.LSD_REFINE_STD)
+
+     def calculate_line_properties(self, line: np.ndarray) -> tuple[float, float, float]:
+         """
+         Calculate line properties: height, width and angle.
+
+         Args:
+             line: Line coordinates [x1, y1, x2, y2]
+
+         Returns:
+             Tuple of (height, width, angle_degrees)
+         """
+         x1, y1, x2, y2 = line
+
+         # Calculate absolute extents (endpoint order is arbitrary)
+         height = abs(y2 - y1)
+         width = abs(x2 - x1)
+
+         # Calculate angle (in degrees)
+         angle = np.degrees(np.arctan2(y2 - y1, x2 - x1))
+
+         return height, width, angle
+
+     def filter_lines(
+         self, lines: np.ndarray, image_height: int, image_width: int
+     ) -> list[np.ndarray]:
+         """
+         Keep lines that are long enough and close to horizontal or vertical.
+
+         Args:
+             lines: Array of line coordinates
+             image_height: Height of the image the lines belong to
+             image_width: Width of the image the lines belong to
+
+         Returns:
+             List of filtered lines
+         """
+         if len(lines) == 0:
+             return lines
+
+         filtered_lines = []
+
+         for line in lines:
+             height, width, angle = self.calculate_line_properties(line)
+
+             # Keep lines exceeding the minimum extent and close to an axis
+             if (
+                 (height > width and height > self.filter_height_ratio * image_height)
+                 or (width >= height and width > self.filter_width_ratio * image_width)
+             ) and (abs(angle) < 10 or abs(angle) > 80):
+                 filtered_lines.append(line)
+
+         return filtered_lines
+
+     def visualize(self, image, lines, output_path="results.jpg"):
+         vis_image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+
+         # Draw lines
+         for line in lines:
+             x1, y1, x2, y2 = line
+             cv2.line(vis_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
+
+         # Display
+         h, w = image.shape
+         plt.figure(figsize=(w / 100, h / 100), dpi=100)
+         plt.axis("off")
+         plt.imshow(vis_image)
+         plt.tight_layout(pad=0)
+         plt.savefig(output_path)
+         plt.close()
+
+     def process(self, image_path: Path, visualize: bool = False) -> list[list[int]]:
+         """
+         Detect line separators in an image and return them as valid bounding boxes.
+         """
+         image = cv2.imread(str(image_path), flags=cv2.IMREAD_GRAYSCALE)
+
+         # Resize and detect
+         resized_image = cv2.resize(image, (self.process_width, self.process_height))
+         lines = self.lsd.detect(resized_image)[0].reshape(-1, 4)
+
+         # Scale lines to target size
+         lines = resize_bbx(
+             lines,
+             width=self.process_width,
+             height=self.process_height,
+             target_width=self.target_width,
+             target_height=self.target_height,
+         )
+
+         # Filter lines
+         lines = self.filter_lines(lines, self.target_height, self.target_width)
+
+         # Transform lines into valid bounding boxes
+         for i in range(len(lines)):
+             lines[i] = lines[i].tolist()
+             if not check_is_valid_bbx(lines[i]):
+                 lines[i] = make_bbx_valid(lines[i])
+
+         # Visualize
+         if visualize:
+             target_image = cv2.resize(image, (self.target_width, self.target_height))
+             self.visualize(target_image, lines, f"{image_path.stem}_filter.jpg")
+
+         return lines
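
For reference, a minimal usage sketch of the detector above. The module path layout_reader.detector is an assumption (only layout_reader.datasets.utils is visible in this diff), and page.jpg is a placeholder image:

    from pathlib import Path

    from layout_reader.detector import LineDetector  # assumed module path

    detector = LineDetector(process_size=(2000, 1000), target_size=(1000, 1000))
    # Returns [x1, y1, x2, y2] separator boxes in the 1000x1000 target space
    separators = detector.process(Path("page.jpg"), visualize=True)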
@@ -0,0 +1,143 @@
+ """Helpers."""
+
+ import gzip
+ import json
+ from pathlib import Path
+
+ import numpy as np
+
+ MAX_LEN = 510
+ UNK_TOKEN_ID = 3
+ # CLS_TOKEN_ID = 0
+ # EOS_TOKEN_ID = 2
+ CLS_SHIFT = 4  # 0 1 2 3 are already taken
+ BBX_FACTOR = 1000
+
+
+ def check_is_valid_bbx(bbx: list[int], min_value=0, max_value=1000) -> bool:
+     """
+     Check that box coordinates are ordered and fall within [min_value, max_value).
+     """
+     x1, y1, x2, y2 = bbx
+     return (
+         (y1 < y2)
+         and (x1 < x2)
+         and (x2 < max_value)
+         and (y2 < max_value)
+         and (x1 >= min_value)
+         and (y1 >= min_value)
+     )
+
+
+ def convert_to_bbx(
+     polygon: list[list[float]], bbx_factor, min_value=0, max_value=1000
+ ) -> list[int]:
+     """
+     Convert a polygon to a bounding box.
+     """
+     # Extract x and y from polygon
+     xs = np.array(polygon)[:, 0]
+     ys = np.array(polygon)[:, 1]
+
+     # Scale between min_value and max_value
+     x1 = max(min_value, int(xs.min() * bbx_factor))
+     y1 = max(min_value, int(ys.min() * bbx_factor))
+     x2 = min(max_value - 1, int(xs.max() * bbx_factor))
+     y2 = min(max_value - 1, int(ys.max() * bbx_factor))
+     return [x1, y1, x2, y2]
+
+
+ def make_bbx_valid(box, min_value: int = 0, max_value: int = 1000) -> list[int]:
+     """
+     Clip a box to the valid range and shift degenerate coordinates apart.
+     """
+     x1, y1, x2, y2 = box
+     # Clip
+     x1 = max(min_value, x1)
+     y1 = max(min_value, y1)
+     x2 = min(max_value - 1, x2)
+     y2 = min(max_value - 1, y2)
+
+     # Shift equal coordinates
+     if x1 == x2:
+         if x2 != max_value - 1:
+             x2 += 1
+         else:
+             x1 -= 1
+     if y1 == y2:
+         if y2 != max_value - 1:
+             y2 += 1
+         else:
+             y1 -= 1
+     return [x1, y1, x2, y2]
+
+
+ def resize_bbx(
+     lines: np.ndarray, width: int, height: int, target_width: int, target_height: int
+ ) -> np.ndarray:
+     """
+     Rescale line coordinates from (width, height) to (target_width, target_height).
+     """
+     w_ratio = target_width / width
+     h_ratio = target_height / height
+     newlines = []
+     for line in lines:
+         x1, y1, x2, y2 = line.tolist()
+         newlines.append(
+             np.array([x1 * w_ratio, y1 * h_ratio, x2 * w_ratio, y2 * h_ratio])
+         )
+     return np.array([sep.astype(int) for sep in newlines])
+
+
+ def save_gzip_jsonl(filename: Path, content: list[dict]) -> None:
+     """
+     Write content in GZIP JSONL format.
+
+     Args:
+         filename (Path): Output filename.
+         content (list[dict]): Records to write, one JSON object per line.
+     """
+     with gzip.open(filename, "wt") as f:
+         f.write("\n".join([json.dumps(c) for c in content]))
+
+
+ def load_gzip_jsonl(filename: Path) -> list[dict]:
+     """
+     Read content in GZIP JSONL format.
+
+     Args:
+         filename (Path): Input filename.
+
+     Returns:
+         One dict per JSONL line.
+     """
+     with gzip.open(filename, "rt") as f:
+         content = f.read().splitlines()
+     return [json.loads(c) for c in content]
+
+
+ def check_too_many_zones(
+     boxes: list[list[int]], separators: list[list[int]], max_len: int = 512
+ ) -> bool:
+     """
+     Check whether the page holds more zones than the model can order.
+     """
+     # Count total objects + [BOS] + [EOS]
+     return (len(separators) + len(boxes) + 2) >= max_len
+
+
+ def load_yolo_line(line, bbx_factor) -> tuple[int, list[int]]:
+     """
+     Parse a YOLO annotation line into a class ID and an [x1, y1, x2, y2] box.
+     """
+     parts = line.strip().split()
+     if len(parts) != 5:
+         raise ValueError(f"Invalid YOLO format: expected 5 values, got {len(parts)}")
+     classif = int(parts[0])
+     x_c, y_c, w, h = map(float, parts[1:])
+     x_min = (x_c - w / 2) * bbx_factor
+     y_min = (y_c - h / 2) * bbx_factor
+     x_max = (x_c + w / 2) * bbx_factor
+     y_max = (y_c + h / 2) * bbx_factor
+     box = [int(x_min), int(y_min), int(x_max), int(y_max)]
+     return classif, box
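
A short sketch of how these helpers compose, assuming they live in layout_reader.datasets.utils (the module the first file imports from); the YOLO line and file name are made up for illustration:

    from pathlib import Path

    from layout_reader.datasets.utils import (
        check_is_valid_bbx,
        load_gzip_jsonl,
        load_yolo_line,
        make_bbx_valid,
        save_gzip_jsonl,
    )

    # Hypothetical YOLO record: class 2, box centered on the page
    classif, box = load_yolo_line("2 0.5 0.5 0.2 0.1", bbx_factor=1000)
    assert check_is_valid_bbx(box)  # box == [400, 450, 600, 550]

    # Degenerate boxes are shifted apart before being stored
    assert make_bbx_valid([400, 300, 400, 300]) == [400, 300, 401, 301]

    save_gzip_jsonl(Path("sample.jsonl.gz"), [{"class": classif, "box": box}])
    assert load_gzip_jsonl(Path("sample.jsonl.gz")) == [{"class": classif, "box": box}]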
@@ -0,0 +1,365 @@
+ import logging
+ import random
+ from collections import defaultdict
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import yaml
+ from colour import Color
+ from datasets import load_dataset
+ from PIL import Image, ImageDraw, ImageFont
+ from transformers import LayoutLMv3ForTokenClassification
+
+ logger = logging.getLogger(__name__)
+ FONT = ImageFont.truetype("fonts/LinuxLibertine.ttf")
+
+
+ # Maximum number of zones to be ordered.
+ MAX_LEN = 510
+
+ # Maximum coordinate after normalization
+ MAX_COOR = 1000
+
+ # Custom classes used by LayoutReader
+ CLS_TOKEN_ID = 0  # Class token
+ PAD_TOKEN_ID = 3  # Padding token
+ EOS_TOKEN_ID = 2  # End-of-sequence token
+
+ # Label to be ignored in the loss computation (padding, separators...)
+ IGNORE_LABEL_ID = -100
+
+ # Width (1000) divided in 10 columns => 100 pixels/column
+ COLUMN_WIDTH = 1000 // 10
+
+
+ def sort_zones(zones, sort_method: str):
+     """Sort zones in place according to sort_method."""
+     if sort_method == "random":
+         # Shuffle in place so the caller sees the new order
+         random.shuffle(zones)
+         return
+
+     sort_keys = {
+         "sortxy_by_column": lambda z: (z[1][0] // COLUMN_WIDTH, z[1][1]),
+         "sortxy": lambda z: (z[1][0], z[1][1]),
+         "sortyx": lambda z: (z[1][1], z[1][0]),
+     }
+     zones.sort(key=sort_keys[sort_method])
+
+
+ def sort_sample(
+     element, sort_ratio: float = 0.5, sort_method: str = "sortxy_by_column"
+ ):
+     """
+     Sort the zones of a page, with probability sort_ratio, using sort_method.
+     """
+     # Get the boxes and classes to sort
+     boxes = element.get("target_boxes") or element.get("source_boxes")
+     classes = element.get("target_classes") or element.get("source_classes") or []
+
+     if random.random() > sort_ratio or not boxes:
+         return element
+
+     if classes:
+         zones = [
+             (i, box, cls)
+             for i, (box, cls) in enumerate(zip(boxes, classes, strict=True))
+         ]
+     else:
+         zones = [(i, box) for i, box in enumerate(boxes)]
+
+     # Sort the zones
+     sort_zones(zones, sort_method)
+
+     if "target_index" in element:
+         target_orders = (
+             np.argsort([zone[0] for zone in zones]) + 1
+         ).tolist()  # start at 1
+         if classes:
+             _, boxes, classes = map(list, zip(*zones, strict=True))
+         else:
+             _, boxes = map(list, zip(*zones, strict=True))
+         element["target_index"] = target_orders
+         element["source_boxes"] = boxes
+         element["source_classes"] = classes
+         return element
+
+     if classes:
+         _, boxes, classes = map(list, zip(*zones, strict=True))
+     else:
+         _, boxes = map(list, zip(*zones, strict=True))
+
+     element["source_boxes"] = boxes
+     element["source_classes"] = classes
+     return element
+
+
+ def read_yaml(filename: str):
+     if not Path(filename).exists():
+         raise FileNotFoundError(f"Configuration not found: {filename}")
+     return yaml.safe_load(Path(filename).read_text())
+
+
+ def load_dataset_split(dataset_path: str, split: str):
+     filename = Path(dataset_path) / f"{split}.jsonl.gz"
+     if not filename.exists():
+         raise FileNotFoundError(f"Dataset file not found: {filename}")
+     try:
+         return load_dataset(
+             "json",
+             data_files={split: str(filename)},
+         )[split]
+     except Exception as e:
+         raise ValueError(f"Failed to load dataset file {filename}") from e
+
+
+ def load_model(model_path: str):
+     try:
+         return LayoutLMv3ForTokenClassification.from_pretrained(
+             model_path,
+             device_map="auto",
+         )
+     except Exception as e:
+         raise ValueError(
+             f"Failed to load model from '{model_path}'. "
+             "Model path must be a valid Hugging Face model ID or a local directory."
+         ) from e
+
+
+ class DataCollator:
+     def __init__(self, with_classes: bool = False, with_separators: bool = False):
+         self.with_classes = with_classes
+         self.with_separators = with_separators
+
+     def _truncate(self, seq: list, name: str, max_len: int) -> list:
+         if len(seq) > max_len:
+             logger.warning(
+                 f"Truncated {name}. Length ({len(seq)}) exceeds max length ({max_len})."
+             )
+             return seq[:max_len]
+         return seq
+
+     def _prepare_single_feature(self, feature: dict) -> dict[str, list]:
+         bboxes = feature["source_boxes"].copy()
+         len_boxes = len(bboxes)
+         separators = list(feature.get("separators", []))
+         if self.with_separators:
+             bboxes.extend(separators)
+         bboxes = self._truncate(bboxes, "bounding boxes", MAX_LEN)
+
+         # Prepare reading order (copy so the feature is not mutated)
+         # Separators should be ignored in the loss
+         labels = list(feature["target_index"])
+         if self.with_separators:
+             labels.extend([IGNORE_LABEL_ID] * len(separators))
+         labels = self._truncate(labels, "labels", MAX_LEN)
+
+         # Prepare classes
+         input_ids = (
+             list(feature["source_classes"])
+             if self.with_classes
+             else [PAD_TOKEN_ID] * len_boxes
+         )
+         if self.with_separators:
+             input_ids.extend([PAD_TOKEN_ID] * len(separators))
+         input_ids = self._truncate(input_ids, "input_ids", MAX_LEN)
+
+         # Prepare attention mask
+         attention_mask = [1] * len(bboxes)
+
+         # Sanity check
+         assert len(bboxes) == len(labels) == len(input_ids) == len(attention_mask), (
+             f"Length mismatch: bbox={len(bboxes)}, labels={len(labels)}, "
+             f"input_ids={len(input_ids)}, attention_mask={len(attention_mask)}"
+         )
+
+         return {
+             "bbox": bboxes,
+             "labels": labels,
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+         }
+
+     def _add_special_tokens(self, batch: dict[str, list[list]]) -> None:
+         """Add CLS and EOS tokens (in-place)."""
+         for i in range(len(batch["bbox"])):
+             batch["bbox"][i] = [[0, 0, 0, 0]] + batch["bbox"][i] + [[0, 0, 0, 0]]
+             batch["labels"][i] = (
+                 [IGNORE_LABEL_ID] + batch["labels"][i] + [IGNORE_LABEL_ID]
+             )
+             batch["input_ids"][i] = (
+                 [CLS_TOKEN_ID] + batch["input_ids"][i] + [EOS_TOKEN_ID]
+             )
+             batch["attention_mask"][i] = [1] + batch["attention_mask"][i] + [1]
+
+     def _pad_sequences(self, batch: dict[str, list[list]]) -> None:
+         """Pad all sequences to max length (in-place)."""
+         max_len = max(len(x) for x in batch["bbox"])
+
+         for i in range(len(batch["bbox"])):
+             pad_len = max_len - len(batch["bbox"][i])
+             batch["bbox"][i] += [[0, 0, 0, 0]] * pad_len
+             batch["labels"][i] += [IGNORE_LABEL_ID] * pad_len
+             batch["input_ids"][i] += [PAD_TOKEN_ID] * pad_len
+             batch["attention_mask"][i] += [0] * pad_len
+
+     def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]:
+         batch = {
+             "bbox": [],
+             "labels": [],
+             "input_ids": [],
+             "attention_mask": [],
+         }
+
+         for feature in features:
+             processed = self._prepare_single_feature(feature)
+             for key in batch:
+                 batch[key].append(processed[key])
+
+         # Add special tokens
+         self._add_special_tokens(batch)
+
+         # Pad to same length
+         self._pad_sequences(batch)
+
+         batch_tensors = {
+             key: torch.tensor(batch[key], dtype=torch.long) for key in batch
+         }
+
+         # Post-process labels: ignore out-of-range orders and shift to 0-based
+         labels_tensor = batch_tensors["labels"]
+         labels_tensor[labels_tensor > MAX_LEN] = IGNORE_LABEL_ID
+         labels_tensor[labels_tensor > 0] -= 1
+         batch_tensors["labels"] = labels_tensor
+
+         return batch_tensors
+
+
+ def boxes_to_inputs(boxes, cls, separators) -> dict[str, torch.Tensor]:
+     """Build a single-sample batch of model inputs from boxes and classes."""
+     all_boxes = boxes.copy()
+     if separators:
+         all_boxes += separators
+         if cls:
+             # Avoid mutating the caller's list
+             cls = cls + [PAD_TOKEN_ID] * len(separators)
+
+     if not cls:
+         cls = [PAD_TOKEN_ID] * len(all_boxes)
+
+     bbox = [[0, 0, 0, 0]] + all_boxes + [[0, 0, 0, 0]]
+     input_ids = [CLS_TOKEN_ID] + cls + [EOS_TOKEN_ID]
+     attention_mask = [1] + [1] * len(all_boxes) + [1]
+     return {
+         "bbox": torch.tensor([bbox]),
+         "attention_mask": torch.tensor([attention_mask]),
+         "input_ids": torch.tensor([input_ids]),
+     }
+
+
+ def prepare_inputs(
+     inputs: dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
+ ) -> dict[str, torch.Tensor]:
+     """Move inputs to the model's device (and dtype, for floating tensors)."""
+     prepared = {}
+     for key, tensor in inputs.items():
+         tensor = tensor.to(model.device)
+         if torch.is_floating_point(tensor):
+             tensor = tensor.to(model.dtype)
+         prepared[key] = tensor
+     return prepared
+
+
+ def parse_logits(logits: torch.Tensor, length: int) -> list[int]:
+     """
+     Convert logits to reading orders.
+     """
+
+     def _find_conflicts(assigned_orders: list[int]) -> dict[int, list[int]]:
+         order_to_elements = defaultdict(list)
+         for element_idx, order in enumerate(assigned_orders):
+             order_to_elements[order].append(element_idx)
+
+         # Keep only positions with conflicts
+         return {
+             order: elements
+             for order, elements in order_to_elements.items()
+             if len(elements) > 1
+         }
+
+     def _resolve_conflicts(
+         conflicts: dict[int, list[int]],
+         assigned_orders: list[int],
+         candidate_orders: list[list[int]],
+         logits: torch.Tensor,
+     ) -> None:
+         """
+         Resolve conflicts based on logits values (keep highest).
+         """
+         for order, element_indices in conflicts.items():
+             # Get logit scores for all elements predicted to current order
+             elements_by_score = [
+                 (element_idx, logits[element_idx, order].item())
+                 for element_idx in element_indices
+             ]
+             elements_by_score.sort(key=lambda x: x[1], reverse=True)
+
+             # Reassign all but the highest-scoring element to next candidates
+             for element_idx, _ in elements_by_score[1:]:
+                 assigned_orders[element_idx] = candidate_orders[element_idx].pop()
+
+     # Extract relevant logits (skip special tokens)
+     logits = logits[1 : length + 1, :length]
+
+     # Get sorted candidate positions for each element (ascending order)
+     # Each row contains position candidates from lowest to highest score
+     candidate_orders = logits.argsort(descending=False).tolist()
+
+     # Initialize with best candidate for each element
+     assigned_orders = [candidates.pop() for candidates in candidate_orders]
+
+     while True:
+         conflicts = _find_conflicts(assigned_orders)
+         if not conflicts:
+             break
+         _resolve_conflicts(conflicts, assigned_orders, candidate_orders, logits)
+
+     return assigned_orders
+
+
+ def check_duplicate(seq: list[int]) -> bool:
+     return len(seq) != len(set(seq))
+
+
+ def save_visualization(
+     image_path: Path,
+     boxes: list[list[int]],
+     predicted_order: list[int],
+     output_path: Path,
+ ):
+     colors = list(Color("red").range_to(Color("green"), len(boxes)))
+     page = Image.open(image_path)
+     draw = ImageDraw.Draw(page)
+
+     center = (0, 0)
+     for order, index in enumerate(predicted_order):
+         x1, y1, x2, y2 = boxes[index]
+         x1 = int(x1 / MAX_COOR * page.width)
+         y1 = int(y1 / MAX_COOR * page.height)
+         x2 = int(x2 / MAX_COOR * page.width)
+         y2 = int(y2 / MAX_COOR * page.height)
+
+         if order > 0:
+             draw.line(
+                 [center, ((x1 + x2) / 2, (y1 + y2) / 2)],
+                 fill=colors[order].hex,
+                 width=2,
+             )
+         draw.rectangle([(x1, y1), (x2, y2)], outline=colors[order].hex, width=4)
+         draw.text(
+             (x1, y1), text=str(order), font=FONT, fill=colors[order].hex, align="left"
+         )
+         center = ((x1 + x2) / 2, (y1 + y2) / 2)
+
+     page.save(output_path, "JPEG")
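
Finally, a minimal end-to-end inference sketch tying these pieces together, assuming the functions above are in scope. The checkpoint path and boxes are placeholders, and it assumes the model's classifier head has at least as many labels as there are zones (parse_logits indexes logits[:, :length]):

    import torch

    model = load_model("path/to/checkpoint")  # placeholder path

    # Three zones in the 0-1000 coordinate space: a header and two columns
    boxes = [[100, 80, 900, 140], [100, 160, 490, 600], [510, 160, 900, 600]]

    inputs = prepare_inputs(boxes_to_inputs(boxes, cls=[], separators=[]), model)
    with torch.no_grad():
        logits = model(**inputs).logits[0]  # drop the batch dimension

    # One reading-order position per zone, conflicts resolved by logit score
    order = parse_logits(logits, length=len(boxes))
    assert not check_duplicate(order)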