teklia-layout-reader 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- layout_reader/__init__.py +9 -0
- layout_reader/cli.py +27 -0
- layout_reader/datasets/__init__.py +99 -0
- layout_reader/datasets/analyze.py +161 -0
- layout_reader/datasets/extract.py +289 -0
- layout_reader/datasets/lsd.py +133 -0
- layout_reader/datasets/utils.py +128 -0
- layout_reader/helpers.py +358 -0
- layout_reader/inference.py +215 -0
- layout_reader/train/sft.py +69 -0
- teklia_layout_reader-0.2.1.dist-info/METADATA +62 -0
- teklia_layout_reader-0.2.1.dist-info/RECORD +22 -0
- teklia_layout_reader-0.2.1.dist-info/WHEEL +5 -0
- teklia_layout_reader-0.2.1.dist-info/entry_points.txt +2 -0
- teklia_layout_reader-0.2.1.dist-info/top_level.txt +2 -0
- tests/__init__.py +3 -0
- tests/conftest.py +19 -0
- tests/test_analyze.py +14 -0
- tests/test_cli.py +11 -0
- tests/test_extract.py +130 -0
- tests/test_helpers.py +438 -0
- tests/test_predict.py +64 -0
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import cv2
|
|
4
|
+
import numpy as np
|
|
5
|
+
from matplotlib import pyplot as plt
|
|
6
|
+
|
|
7
|
+
from layout_reader.datasets.utils import check_is_valid_bbx, make_bbx_valid, resize_bbx
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class LineDetector:
    """
    Detect long, near-horizontal and near-vertical line segments on page images
    with OpenCV's Line Segment Detector (LSD).
    """

    def __init__(
        self,
        process_size: tuple[int, int] = (2000, 1000),
        target_size: tuple[int, int] = (1000, 1000),
        filter_ratio: tuple[float, float] = (0.05, 0.05),
    ):
        """
        Initialize the line detector.

        Args:
            process_size: (width, height) the image is resized to before running LSD.
            target_size: (width, height) of the coordinate space the detected lines
                are rescaled to.
            filter_ratio: (width_ratio, height_ratio). A mostly-horizontal segment is
                kept when its horizontal extent exceeds width_ratio * target width;
                a mostly-vertical one when its vertical extent exceeds
                height_ratio * target height.
        """
        self.process_width, self.process_height = process_size
        self.target_width, self.target_height = target_size
        self.filter_width_ratio, self.filter_height_ratio = filter_ratio

        # Create LSD detector
        self.lsd = cv2.createLineSegmentDetector(cv2.LSD_REFINE_STD)

    def calculate_line_properties(self, line: np.ndarray) -> tuple[float, float, float]:
        """
        Compute the signed extents and orientation of a line segment.

        Args:
            line: Line coordinates [x1, y1, x2, y2].

        Returns:
            Tuple of (height, width, angle_degrees) where height = y2 - y1 and
            width = x2 - x1 are signed (negative when the segment points up / left)
            and the angle, from np.arctan2, lies in [-180, 180] degrees.
        """
        x1, y1, x2, y2 = line
        height = y2 - y1
        width = x2 - x1
        angle = np.degrees(np.arctan2(height, width))
        return height, width, angle

    def filter_lines(
        self, lines: np.ndarray, image_height: int, image_width: int
    ) -> list[np.ndarray]:
        """
        Keep long, axis-aligned line segments.

        A segment is kept when it is within 10 degrees of horizontal or vertical
        and its extent along its dominant axis exceeds the configured ratio of the
        image size. Which endpoint comes first does not affect the decision.

        Args:
            lines: Line coordinates, one [x1, y1, x2, y2] row per segment.
            image_height: Height of the coordinate space of `lines`.
            image_width: Width of the coordinate space of `lines`.

        Returns:
            List of the kept rows of `lines`.
        """
        if len(lines) == 0:
            return []

        filtered_lines = []
        for line in lines:
            height, width, angle = self.calculate_line_properties(line)

            # Nothing orders a segment's endpoints, so compare absolute extents and
            # fold the angle into [0, 90] (deviation from horizontal): a segment and
            # its reversed copy must get the same verdict. Without this, a leftward
            # diagonal (|angle| = 135) passed the "near vertical" test.
            height, width = abs(height), abs(width)
            deviation = abs(angle)
            deviation = min(deviation, 180 - deviation)

            is_long = (
                height > width and height > self.filter_height_ratio * image_height
            ) or (width >= height and width > self.filter_width_ratio * image_width)
            is_axis_aligned = deviation < 10 or deviation > 80
            if is_long and is_axis_aligned:
                filtered_lines.append(line)

        return filtered_lines

    def visualize(self, image, lines, output_path="results.jpg"):
        """
        Save a copy of the grayscale `image` with `lines` drawn on it.

        Args:
            image: Single-channel image (converted to RGB for drawing).
            lines: Iterable of integer [x1, y1, x2, y2] segments.
            output_path: Destination file of the figure.
        """
        vis_image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        # Draw lines
        for line in lines:
            x1, y1, x2, y2 = line
            cv2.line(vis_image, (x1, y1), (x2, y2), (255, 0, 0), 2)

        # Display at the native pixel size of the image (100 dpi, figsize in inches)
        h, w = image.shape
        plt.figure(figsize=(w / 100, h / 100), dpi=100)
        plt.axis("off")
        plt.imshow(vis_image)
        plt.tight_layout(pad=0)
        plt.savefig(output_path)
        plt.close()

    def process(self, image_path: Path, visualize: bool = False) -> list[list[int]]:
        """
        Detect line segments on an image and return them as valid bounding boxes.

        Args:
            image_path: Path to the image (read in grayscale).
            visualize: If True, also save `<stem>_filter.jpg` showing the kept lines.

        Returns:
            One [x1, y1, x2, y2] box (x1 <= x2, y1 <= y2) per kept segment, in
            target-size coordinates.

        Raises:
            ValueError: If the image cannot be read.
        """
        image_path = Path(image_path)
        image = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Unable to read image: {image_path}")

        # Resize and detect
        resized_image = cv2.resize(image, (self.process_width, self.process_height))
        detected = self.lsd.detect(resized_image)[0]
        if detected is None:
            # The OpenCV binding may return None instead of an empty array
            detected = np.empty((0, 4), dtype=np.float32)
        lines = detected.reshape(-1, 4)

        # Scale lines to target size
        lines = resize_bbx(
            lines,
            width=self.process_width,
            height=self.process_height,
            target_width=self.target_width,
            target_height=self.target_height,
        )

        # Filter lines
        lines = self.filter_lines(lines, self.target_height, self.target_width)

        # Transform lines into valid bounding boxes
        boxes = []
        for line in lines:
            x1, y1, x2, y2 = np.asarray(line).tolist()
            # Order the endpoints so that x1 <= x2 and y1 <= y2
            box = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]
            if not check_is_valid_bbx(box):
                box = make_bbx_valid(box)
            boxes.append(box)

        # Visualize
        if visualize:
            target_image = cv2.resize(image, (self.target_width, self.target_height))
            self.visualize(target_image, boxes, f"{image_path.stem}_filter.jpg")

        return boxes
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""Helpers."""
|
|
2
|
+
|
|
3
|
+
import gzip
|
|
4
|
+
import json
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
# Maximum sequence length, in zones (presumably; confirm against callers)
MAX_LEN: int = 510

# Special token ids 0, 1, 2 and 3 are already taken (CLS = 0, EOS = 2,
# UNK = 3), hence a shift of 4 for class ids.
UNK_TOKEN_ID: int = 3
CLS_SHIFT: int = 4

# Scale factor for coordinates (presumably normalized values -> 0..1000 range)
BBX_FACTOR: int = 1000
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def check_is_valid_bbx(bbx: list[int], min_value=0, max_value=1000) -> bool:
    """
    Tell whether a [x1, y1, x2, y2] box is well-formed and inside the allowed range.

    A valid box has strictly increasing coordinates on both axes, its top-left
    corner at or above `min_value`, and its bottom-right corner strictly below
    `max_value`.
    """
    left, top, right, bottom = bbx
    is_ordered = left < right and top < bottom
    is_in_range = (
        min_value <= left
        and min_value <= top
        and right < max_value
        and bottom < max_value
    )
    return is_ordered and is_in_range
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def convert_to_bbx(
    polygon: list[list[float]], bbx_factor, min_value=0, max_value=1000
) -> list[int]:
    """
    Return the scaled, clipped axis-aligned bounding box of a polygon.

    Point coordinates are multiplied by `bbx_factor` and truncated to int; the
    top-left corner is clipped to at least `min_value` and the bottom-right corner
    to at most `max_value - 1`.

    Args:
        polygon: List of [x, y] points.
        bbx_factor: Scale applied to every coordinate.
        min_value: Lower bound for x1 / y1.
        max_value: Exclusive upper bound for x2 / y2.

    Returns:
        [x1, y1, x2, y2].
    """
    points = np.array(polygon)
    x_coords, y_coords = points[:, 0], points[:, 1]
    upper = max_value - 1
    return [
        max(min_value, int(x_coords.min() * bbx_factor)),
        max(min_value, int(y_coords.min() * bbx_factor)),
        min(upper, int(x_coords.max() * bbx_factor)),
        min(upper, int(y_coords.max() * bbx_factor)),
    ]
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def make_bbx_valid(box, min_value: int = 0, max_value: int = 1000) -> list[int]:
    """
    Turn an arbitrary [x1, y1, x2, y2] box into a valid one.

    Every coordinate is clipped to [min_value, max_value - 1], each axis is ordered
    so that x1 <= x2 and y1 <= y2, and degenerate (zero-size) axes are widened by
    one unit. This is done by growing the max side, or the min side when the max
    side is already at the upper bound.

    Args:
        box: [x1, y1, x2, y2], possibly out of range or with reversed coordinates.
        min_value: Smallest allowed coordinate.
        max_value: Exclusive upper bound of the coordinates.

    Returns:
        [x1, y1, x2, y2] with min_value <= x1 < x2 < max_value and the same for y.
    """
    upper = max_value - 1

    # Clip every coordinate on both sides (a too-large x1 would otherwise end
    # up greater than the clipped x2)
    x1, y1, x2, y2 = (min(max(value, min_value), upper) for value in box)

    # Order each axis (e.g. line segments drawn right-to-left)
    x1, x2 = sorted((x1, x2))
    y1, y2 = sorted((y1, y2))

    # Shift equal coordinates
    if x1 == x2:
        if x2 != upper:
            x2 += 1
        else:
            x1 -= 1
    if y1 == y2:
        if y2 != upper:
            y2 += 1
        else:
            # Previously decremented x2 here, leaving y1 == y2
            y1 -= 1
    return [x1, y1, x2, y2]
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def resize_bbx(
    lines: np.ndarray, width: int, height: int, target_width: int, target_height: int
) -> np.ndarray:
    """
    Rescale [x1, y1, x2, y2] coordinates from a (width, height) space to a
    (target_width, target_height) space.

    Args:
        lines: Array-like of shape (N, 4); a list of 4-element rows is accepted too.
        width: Width of the source coordinate space.
        height: Height of the source coordinate space.
        target_width: Width of the destination coordinate space.
        target_height: Height of the destination coordinate space.

    Returns:
        Integer array of shape (N, 4), coordinates truncated toward zero.
        An empty input gives an empty (0, 4) array.

    Raises:
        ValueError: If a non-empty input is not of shape (N, 4).
    """
    w_ratio = target_width / width
    h_ratio = target_height / height

    coords = np.asarray(lines, dtype=float)
    if coords.size == 0:
        return np.empty((0, 4), dtype=int)
    if coords.ndim != 2 or coords.shape[1] != 4:
        raise ValueError(f"Expected lines of shape (N, 4), got {coords.shape}")

    scale = np.array([w_ratio, h_ratio, w_ratio, h_ratio])
    return (coords * scale).astype(int)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def save_gzip_jsonl(filename: Path, content: list[dict]) -> None:
    """
    Write records to a gzip-compressed JSON Lines file.

    Each record is serialized with `json.dumps`; records are separated by a
    newline, with no newline after the last one.

    Args:
        filename (Path): Output filename.
        content (list[dict]): Records to write, one JSON object per line.
    """
    with gzip.open(filename, "wt") as handle:
        serialized = map(json.dumps, content)
        handle.write("\n".join(serialized))
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def load_gzip_jsonl(filename: Path) -> list[dict]:
    """
    Read records from a gzip-compressed JSON Lines file.

    The file is decoded as UTF-8 (the encoding JSON text uses) and streamed line
    by line. Blank lines, such as a trailing newline or an empty separator line,
    are skipped instead of making `json.loads` fail.

    Args:
        filename (Path): Input filename.

    Returns:
        list[dict]: One decoded object per non-blank line.
    """
    with gzip.open(filename, "rt", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def check_too_many_zones(
    boxes: list[list[int]], separators: list[list[int]], max_len: int = 512
) -> bool:
    """
    Tell whether a page has too many objects for one model sequence.

    The sequence holds every box and separator plus the [BOS] and [EOS] tokens.
    It is flagged as too long as soon as that total reaches `max_len`.
    """
    special_tokens = 2  # [BOS] + [EOS]
    sequence_length = len(boxes) + len(separators) + special_tokens
    return sequence_length >= max_len
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def load_yolo_line(line, bbx_factor) -> tuple[int, list[int]]:
    """
    Parse one YOLO annotation line into a class id and a scaled bounding box.

    The line holds "class x_center y_center width height". The geometry is
    presumably normalized to [0, 1]. The box corners are multiplied by
    `bbx_factor` and truncated to int.

    Returns:
        (class_id, [x_min, y_min, x_max, y_max]).

    Raises:
        ValueError: If the line does not hold exactly 5 values.
    """
    # str.split() without arguments already ignores surrounding whitespace
    fields = line.split()
    if len(fields) != 5:
        raise ValueError(f"Invalid YOLO format: expected 5 values, got {len(fields)}")

    label = int(fields[0])
    x_c, y_c, w, h = (float(value) for value in fields[1:])
    half_w, half_h = w / 2, h / 2
    corners = (x_c - half_w, y_c - half_h, x_c + half_w, y_c + half_h)
    return label, [int(corner * bbx_factor) for corner in corners]
|
layout_reader/helpers.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import random
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
import yaml
|
|
9
|
+
from colour import Color
|
|
10
|
+
from datasets import load_dataset
|
|
11
|
+
from PIL import Image, ImageDraw, ImageFont
|
|
12
|
+
from transformers import LayoutLMv3ForTokenClassification
|
|
13
|
+
|
|
14
|
+
logger = logging.getLogger(__name__)

# Font used to number the zones in visualizations. The path is resolved against
# the current working directory and the font does not appear to be shipped with
# the package, so fall back to Pillow's default font instead of failing at import.
FONT_PATH = "fonts/LinuxLibertine.ttf"


def _load_font(path: str = FONT_PATH):
    """Load the TrueType font at `path`, or Pillow's default font if it cannot be opened."""
    try:
        return ImageFont.truetype(path)
    except OSError:
        logger.warning("Could not load font %s, using Pillow's default font.", path)
        return ImageFont.load_default()


FONT = _load_font()


# Maximum number of zones to be ordered.
MAX_LEN = 510

# Maximum coordinate after normalization
MAX_COOR = 1000

# Special token ids used by LayoutReader
CLS_TOKEN_ID = 0  # Class token
PAD_TOKEN_ID = 3  # Padding token
EOS_TOKEN_ID = 2  # End-of-sequence token

# Label to be ignored in the loss computation (padding, separators...)
IGNORE_LABEL_ID = -100

# Width (1000) divided in 10 columns => 100 coordinate units per column
COLUMN_WIDTH = 1000 // 10
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def sort_zones(zones, sort_method: str):
    """
    Reorder a list of zones in place.

    Args:
        zones: List of tuples whose second item is a box (x1, y1, ...).
        sort_method: One of
            - "random": shuffle,
            - "sortxy_by_column": by column (x1 // COLUMN_WIDTH), then y1,
            - "sortxy": by x1, then y1,
            - "sortyx": by y1, then x1.

    Returns:
        The same `zones` list, reordered (callers may also ignore the return value).

    Raises:
        KeyError: If `sort_method` is unknown.
    """
    if sort_method == "random":
        # Shuffle the caller's list itself: shuffling a copy (as done before)
        # left the zones untouched for callers relying on the in-place update.
        random.shuffle(zones)
        return zones

    sort_keys = {
        "sortxy_by_column": lambda z: (z[1][0] // COLUMN_WIDTH, z[1][1]),
        "sortxy": lambda z: (z[1][0], z[1][1]),
        "sortyx": lambda z: (z[1][1], z[1][0]),
    }
    try:
        sort_key = sort_keys[sort_method]
    except KeyError:
        raise KeyError(
            f"Unknown sort method '{sort_method}', "
            f"expected 'random' or one of {sorted(sort_keys)}"
        ) from None
    zones.sort(key=sort_key)
    return zones
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def sort_sample(
    element, sort_ratio: float = 0.5, sort_method: str = "sortxy_by_column"
):
    """
    Randomly reorder the zones of a page sample.

    With probability `sort_ratio`, the zones are reordered with `sort_zones` and
    written back as `source_boxes` / `source_classes`. When the sample has a
    `target_index`, it is recomputed from the new zone order.

    Args:
        element: Sample dict with `target_boxes` or `source_boxes`, optional
            `target_classes` / `source_classes` and an optional `target_index`.
            It is updated in place.
        sort_ratio: Probability of reordering the sample.
        sort_method: Method name passed to `sort_zones`.

    Returns:
        The (possibly updated) element.
    """
    # Get boxes or classes and sort them
    boxes = element.get("target_boxes") or element.get("source_boxes")
    classes = element.get("target_classes") or element.get("source_classes") or []

    if random.random() > sort_ratio or not boxes:
        return element

    # Each zone keeps its original position so the new order can be traced back
    if classes:
        zones = [
            (i, box, cls)
            for i, (box, cls) in enumerate(zip(boxes, classes, strict=True))
        ]
    else:
        zones = [(i, box) for i, box in enumerate(boxes)]

    # Sort the zones (in place)
    sort_zones(zones, sort_method)

    # Unpack according to the tuple width: samples without classes used to
    # crash here when they had a `target_index`
    if classes:
        positions, boxes, classes = map(list, zip(*zones, strict=True))
    else:
        positions, boxes = map(list, zip(*zones, strict=True))

    if "target_index" in element:
        # target_index[i] is the new (1-based) position of the zone
        # originally at index i
        element["target_index"] = (np.argsort(positions) + 1).tolist()

    element["source_boxes"] = boxes
    element["source_classes"] = classes
    return element
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def read_yaml(filename: str):
    """
    Load a YAML configuration file with `yaml.safe_load`.

    Raises:
        FileNotFoundError: If `filename` does not exist.
    """
    config_path = Path(filename)
    if not config_path.exists():
        raise FileNotFoundError(f"Configuration not found: {filename}")
    raw_text = config_path.read_text()
    return yaml.safe_load(raw_text)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def load_dataset_split(dataset_path: str, split: str):
    """
    Load `<dataset_path>/<split>.jsonl.gz` as a Hugging Face dataset split.

    Raises:
        FileNotFoundError: If the split file does not exist.
        ValueError: If the file exists but cannot be loaded.
    """
    data_file = Path(dataset_path) / f"{split}.jsonl.gz"
    if not data_file.exists():
        raise FileNotFoundError(f"Dataset file not found: {data_file}")

    try:
        splits = load_dataset("json", data_files={split: str(data_file)})
        return splits[split]
    except Exception as error:
        raise ValueError(f"Failed to load dataset file {data_file}") from error
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def load_model(model_path: str):
    """
    Load a LayoutLMv3 token-classification model (with `device_map="auto"`).

    Args:
        model_path: Hugging Face model ID or local directory.

    Raises:
        ValueError: Wrapping any error raised while loading the model.
    """
    try:
        model = LayoutLMv3ForTokenClassification.from_pretrained(
            model_path,
            device_map="auto",
        )
    except Exception as error:
        raise ValueError(
            f"Failed to load model from '{model_path}'. "
            "Model path must be a valid Hugging Face model ID or a local directory."
        ) from error
    return model
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class DataCollator:
    """
    Collate reading-order samples into padded LayoutLMv3 token-classification batches.

    Each feature must provide `source_boxes` and `target_index` (1-based target
    positions, shifted to 0-based in `__call__`), `source_classes` when
    `with_classes` is True, and may provide `separators`.
    """

    def __init__(self, with_classes: bool = False, with_separators: bool = False):
        # Feed zone classes as input ids (otherwise every zone gets PAD_TOKEN_ID)
        self.with_classes = with_classes
        # Append separator boxes after the zones; they are excluded from the loss
        self.with_separators = with_separators

    def _truncate(self, seq: list, name: str, max_len: int) -> list:
        """Return the first `max_len` items of `seq`, warning when items are dropped."""
        if len(seq) > max_len:
            logger.warning(
                f"Truncated {name}. Length ({len(seq)}) exceeds MAX_LEN ({max_len})."
            )
            return seq[:max_len]
        return seq

    def _prepare_single_feature(self, feature: dict) -> dict[str, list]:
        """
        Build the unpadded `bbox`, `labels`, `input_ids` and `attention_mask` lists
        of one sample.

        Every list is a fresh copy: the feature itself is never modified, so
        collating the same feature twice yields the same output (previously the
        feature's `target_index` / `source_classes` lists were extended in place).
        """
        num_zones = len(feature["source_boxes"])
        separators = feature.get("separators")
        separators = [] if separators is None else list(separators)

        bboxes = list(feature["source_boxes"])
        labels = list(feature["target_index"])
        input_ids = (
            list(feature["source_classes"])
            if self.with_classes
            else [PAD_TOKEN_ID] * num_zones
        )

        if self.with_separators:
            bboxes.extend(separators)
            # Separators have no reading order: ignore them in the loss
            labels.extend([IGNORE_LABEL_ID] * len(separators))
            input_ids.extend([PAD_TOKEN_ID] * len(separators))

        bboxes = self._truncate(bboxes, "bounding boxes", MAX_LEN)
        labels = self._truncate(labels, "labels", MAX_LEN)
        input_ids = self._truncate(input_ids, "input_ids", MAX_LEN)

        # Prepare attention mask
        attention_mask = [1] * len(bboxes)

        # Sanity check
        assert len(bboxes) == len(labels) == len(input_ids) == len(attention_mask), (
            f"Length mismatch: bbox={len(bboxes)}, labels={len(labels)}, "
            f"input_ids={len(input_ids)}, attention_mask={len(attention_mask)}"
        )

        return {
            "bbox": bboxes,
            "labels": labels,
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

    def _add_special_tokens(self, batch: dict[str, list[list]]) -> None:
        """Add CLS and EOS positions around every sequence (in-place)."""
        for i in range(len(batch["bbox"])):
            batch["bbox"][i] = [[0, 0, 0, 0]] + batch["bbox"][i] + [[0, 0, 0, 0]]
            batch["labels"][i] = (
                [IGNORE_LABEL_ID] + batch["labels"][i] + [IGNORE_LABEL_ID]
            )
            batch["input_ids"][i] = (
                [CLS_TOKEN_ID] + batch["input_ids"][i] + [EOS_TOKEN_ID]
            )
            batch["attention_mask"][i] = [1] + batch["attention_mask"][i] + [1]

    def _pad_sequences(self, batch: dict[str, list[list]]) -> None:
        """Pad all sequences to the longest one of the batch (in-place)."""
        max_len = max(len(x) for x in batch["bbox"])

        for i in range(len(batch["bbox"])):
            pad_len = max_len - len(batch["bbox"][i])
            batch["bbox"][i] += [[0, 0, 0, 0]] * pad_len
            batch["labels"][i] += [IGNORE_LABEL_ID] * pad_len
            batch["input_ids"][i] += [PAD_TOKEN_ID] * pad_len
            batch["attention_mask"][i] += [0] * pad_len

    def __call__(self, features: list[dict]) -> dict[str, torch.Tensor]:
        """
        Collate `features` into `bbox`, `labels`, `input_ids` and `attention_mask`
        long tensors of shape (batch, sequence).
        """
        batch = {
            "bbox": [],
            "labels": [],
            "input_ids": [],
            "attention_mask": [],
        }

        for feature in features:
            processed = self._prepare_single_feature(feature)
            for key in batch:
                batch[key].append(processed[key])

        # Add special tokens
        self._add_special_tokens(batch)

        # Pad to same length
        self._pad_sequences(batch)

        batch_tensors = {
            key: torch.tensor(batch[key], dtype=torch.long) for key in batch
        }

        # Post-process labels: positions above MAX_LEN are ignored, then the
        # remaining 1-based positions are shifted to 0-based class indices
        labels_tensor = batch_tensors["labels"]
        labels_tensor[labels_tensor > MAX_LEN] = IGNORE_LABEL_ID
        labels_tensor[labels_tensor > 0] -= 1
        batch_tensors["labels"] = labels_tensor

        return batch_tensors
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def boxes_to_inputs(boxes, cls, separators) -> dict[str, torch.Tensor]:
    """
    Build a batch-of-one LayoutLMv3 input for inference.

    Args:
        boxes: Zone boxes [x1, y1, x2, y2].
        cls: Zone class ids, or an empty / None value to use PAD_TOKEN_ID for
            every zone. It is not modified.
        separators: Separator boxes appended after the zones, or empty / None.

    Returns:
        Dict with `bbox`, `attention_mask` and `input_ids` tensors of shape
        (1, len(boxes) + len(separators) + 2), framed by the CLS and EOS positions.
    """
    separators = list(separators) if separators else []
    all_boxes = list(boxes) + separators

    if cls:
        # Build a new list: `cls += ...` used to extend the caller's list in place
        input_classes = list(cls) + [PAD_TOKEN_ID] * len(separators)
    else:
        input_classes = [PAD_TOKEN_ID] * len(all_boxes)

    bbox = [[0, 0, 0, 0]] + all_boxes + [[0, 0, 0, 0]]
    input_ids = [CLS_TOKEN_ID] + input_classes + [EOS_TOKEN_ID]
    attention_mask = [1] + [1] * len(all_boxes) + [1]
    return {
        "bbox": torch.tensor([bbox]),
        "attention_mask": torch.tensor([attention_mask]),
        "input_ids": torch.tensor([input_ids]),
    }
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def prepare_inputs(
    inputs: dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
) -> dict[str, torch.Tensor]:
    """
    Move every input tensor to the model's device.

    Floating-point tensors are also cast to the model's dtype; integer tensors
    keep their dtype.
    """

    def _adapt(tensor: torch.Tensor) -> torch.Tensor:
        moved = tensor.to(model.device)
        return moved.to(model.dtype) if torch.is_floating_point(moved) else moved

    return {key: _adapt(tensor) for key, tensor in inputs.items()}
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def parse_logits(logits: torch.Tensor, length: int) -> list[int]:
    """
    Convert logits to one distinct reading order per element.

    Each element first takes its highest-scoring order. While several elements
    claim the same order, the one with the highest logit for that order keeps it
    and the others fall back to their next-best remaining order.

    Args:
        logits: 2D logits of one sample. Rows are tokens (row 0 is a special token
            and is skipped); columns are candidate orders.
        length: Number of elements to decode.

    Returns:
        assigned_orders, where assigned_orders[i] is the 0-based order of element i.

    Raises:
        ValueError: If there are fewer candidate orders than elements, in which case
            no one-to-one assignment exists.
    """

    def _find_conflicts(assigned_orders: list[int]) -> dict[int, list[int]]:
        """Map every order claimed by several elements to those elements' indices."""
        order_to_elements = defaultdict(list)
        for element_idx, order in enumerate(assigned_orders):
            order_to_elements[order].append(element_idx)

        # Keep only positions with conflicts
        return {
            order: elements
            for order, elements in order_to_elements.items()
            if len(elements) > 1
        }

    def _resolve_conflicts(
        conflicts: dict[int, list[int]],
        assigned_orders: list[int],
        candidate_orders: list[list[int]],
        logits: torch.Tensor,
    ) -> None:
        """
        Resolve conflicts based on logits values (keep highest).

        Losing elements are moved, in place, to their next-best candidate.
        """
        for order, element_indices in conflicts.items():
            # Get logit scores for all elements predicted to current order
            elements_by_score = [
                (element_idx, logits[element_idx, order].item())
                for element_idx in element_indices
            ]
            elements_by_score.sort(key=lambda x: x[1], reverse=True)

            # Reassign all but the highest-scoring element to next candidates
            for element_idx, _ in elements_by_score[1:]:
                assigned_orders[element_idx] = candidate_orders[element_idx].pop()

    # Extract relevant logits (skip special tokens). Copy them to the CPU once, as
    # float64 (exact for every lower-precision float dtype), so that the many
    # `.item()` calls below do not each trigger a device transfer.
    logits = logits[1 : length + 1, :length].detach().to(
        device="cpu", dtype=torch.float64
    )
    num_elements, num_orders = logits.shape[0], logits.shape[1]
    if num_orders < num_elements:
        # Some element would exhaust its candidates ("pop from empty list")
        raise ValueError(
            f"Cannot assign {num_elements} elements to only {num_orders} reading orders"
        )

    # Get sorted candidate positions for each element (ascending order)
    # Each row contains position candidates from lowest to highest score
    candidate_orders = logits.argsort(descending=False).tolist()

    # Initialize with best candidate for each element
    assigned_orders = [candidates.pop() for candidates in candidate_orders]

    while True:
        conflicts = _find_conflicts(assigned_orders)
        if not conflicts:
            break
        _resolve_conflicts(conflicts, assigned_orders, candidate_orders, logits)

    return assigned_orders
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def check_duplicate(seq: list[int]) -> bool:
    """Return True when `seq` contains at least one repeated value."""
    distinct_values = set(seq)
    return len(distinct_values) != len(seq)
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def save_visualization(
    image_path: Path,
    boxes: list[list[int]],
    predicted_order: list[int],
    output_path: Path,
):
    """
    Draw the zones and the reading path on the page image and save it as JPEG.

    Args:
        image_path: Page image.
        boxes: Zone boxes in the [0, MAX_COOR] normalized coordinate space.
        predicted_order: Zone indices; the k-th drawn zone is boxes[predicted_order[k]]
            and is labelled `k`. Colors go from red (first) to green (last).
        output_path: Destination of the JPEG visualization.
    """
    colors = list(Color("red").range_to(Color("green"), len(boxes)))

    # Decode into an RGB copy and release the file handle: JPEG cannot store
    # alpha or palette images (e.g. RGBA / P mode PNGs), which made `save` fail.
    with Image.open(image_path) as source:
        page = source.convert("RGB")
    draw = ImageDraw.Draw(page)

    previous_center = None
    for order, index in enumerate(predicted_order):
        x1, y1, x2, y2 = boxes[index]
        # Scale normalized coordinates back to pixels
        x1 = int(x1 / MAX_COOR * page.width)
        y1 = int(y1 / MAX_COOR * page.height)
        x2 = int(x2 / MAX_COOR * page.width)
        y2 = int(y2 / MAX_COOR * page.height)

        color = colors[order].hex
        center = ((x1 + x2) / 2, (y1 + y2) / 2)
        if previous_center is not None:
            # Link consecutive zones to show the reading path
            draw.line([previous_center, center], fill=color, width=2)
        draw.rectangle([(x1, y1), (x2, y2)], outline=color, width=4)
        draw.text((x1, y1), text=str(order), font=FONT, fill=color, align="left")
        previous_center = center

    page.save(output_path, "JPEG")
|