teklia_layout_reader-0.2.1-py3-none-any.whl

layout_reader/inference.py ADDED
@@ -0,0 +1,215 @@
+ import json
+ import logging
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+ from typing import Any
+
+ from layout_reader.helpers import (
+     boxes_to_inputs,
+     load_dataset_split,
+     load_model,
+     parse_logits,
+     prepare_inputs,
+     save_visualization,
+     sort_sample,
+ )
+
+ logger = logging.getLogger(__name__)
+ DATA_FILES = ["train", "val", "dev", "test"]
+
+
+ @dataclass
+ class PageResult:
+     boxes: list[Any]  # List of bounding boxes to reorder
+     classes: list[Any]  # List of corresponding classes
+     separators: list[Any]  # List of separators, used as context by the model
+     predicted_order: list[int]  # Detected reading order
+
+     # Optional evaluation outputs
+     target_order: list[int] | None = None  # Target reading order
+
+     # Optional visualization output
+     visualization: str | None = None  # Path to the visualization image
+
+     @property
+     def average_relative_distance(self) -> float:
+         """
+         Compute the Average Relative Distance (ARD) between predicted and ground-truth ordering.
+         """
+         gt_zones = [self.boxes[i] for i in self.target_order]
+         predicted_zones = [self.boxes[i] for i in self.predicted_order]
+
+         ard = 0
+
+         if not gt_zones and not predicted_zones:
+             return 0
+
+         if not gt_zones or not predicted_zones:
+             return float(max(len(gt_zones), len(predicted_zones)))
+
+         for pred_idx, pred_zone in enumerate(predicted_zones):
+             if pred_zone in gt_zones:
+                 ard += abs(gt_zones.index(pred_zone) - pred_idx)
+             else:
+                 ard += len(predicted_zones)
+
+         return ard / len(predicted_zones)
+
+     def to_json(self) -> dict[str, Any]:
+         data = asdict(self)
+
+         if self.target_order is not None:
+             data["average_relative_distance"] = self.average_relative_distance
+
+         return data
+
+
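The ARD property above averages, over all predicted zones, the absolute gap between each zone's predicted and ground-truth positions, with a penalty of `len(predicted_zones)` for zones missing from the ground truth. A minimal sketch of the computation on a toy ordering (the module path follows the RECORD listing below):

```python
# Toy check of the ARD logic above: three boxes, predicted order [2, 0, 1]
# against the identity ground truth [0, 1, 2].
from layout_reader.inference import PageResult

result = PageResult(
    boxes=[[0, 0, 10, 10], [0, 20, 10, 30], [0, 40, 10, 50]],
    classes=[],
    separators=[],
    predicted_order=[2, 0, 1],  # model puts the third box first
    target_order=[0, 1, 2],     # ground truth is the identity order
)
# gt_zones = [b0, b1, b2]; predicted_zones = [b2, b0, b1]
# positional gaps: |2 - 0| + |0 - 1| + |1 - 2| = 4, so ARD = 4 / 3
assert abs(result.average_relative_distance - 4 / 3) < 1e-9
```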
+ def add_predict_command(subcommands) -> None:
+     parser = subcommands.add_parser(
+         "inference",
+         description="Predict and evaluate a model on a Layout Reader dataset.",
+         help="Predict and evaluate a model on a Layout Reader dataset.",
+     )
+     parser.add_argument(
+         "--dataset",
+         type=str,
+         help="Path to the local LayoutReader dataset directory. The directory must contain .jsonl.gz files.",
+     )
+     parser.add_argument(
+         "--split",
+         choices=DATA_FILES,
+         help="Dataset split to use. Must match the name of a corresponding archive in the --dataset directory (e.g., 'train' -> train.jsonl.gz).",
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         help="Name or path of the LayoutReader model checkpoint",
+     )
+     parser.add_argument(
+         "--output-dir",
+         type=Path,
+         help="Output directory where results will be saved. If --visualize is set, the predicted reading order will also be plotted on each image.",
+         required=True,
+     )
+     parser.add_argument(
+         "--with-classes",
+         action="store_true",
+         help="Whether to use the zone classes for the prediction.",
+     )
+     parser.add_argument(
+         "--with-separators",
+         action="store_true",
+         help="Whether to use the separators for the prediction.",
+     )
+     parser.add_argument(
+         "--sort-method",
+         choices=["sortxy_by_column", "sortxy", "sortyx", "random"],
+         default=None,
+         help="How to pre-order the input zones. By default, no sorting method is applied.",
+     )
+     parser.add_argument(
+         "--images",
+         type=Path,
+         help="Path to the images (required to visualize predictions).",
+         default=None,
+     )
+     parser.add_argument(
+         "--visualize",
+         action="store_true",
+         help="Whether to visualize the predicted reading order. Plots will be saved in the output directory.",
+     )
+     parser.set_defaults(func=run)
+
+
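For reference, a hypothetical programmatic equivalent of the `inference` subcommand registered above, calling `run` directly with placeholder paths (the module path follows the RECORD listing below):

```python
from pathlib import Path

from layout_reader.inference import run

# Placeholder paths; keyword arguments mirror the CLI flags above.
run(
    dataset="data/lr_dataset",   # directory containing <split>.jsonl.gz files
    split="test",
    model="path/to/checkpoint",  # LayoutReader model checkpoint
    output_dir=Path("out"),
    with_classes=True,
    with_separators=False,
    visualize=False,
)
```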
+ def predict(model, boxes, classes, separators) -> list[int]:
+     if not boxes:
+         return []
+     inputs = boxes_to_inputs(boxes, cls=classes, separators=separators)
+     inputs = prepare_inputs(inputs, model)
+     logits = model(**inputs).logits.squeeze(0)
+     predicted_orders = parse_logits(logits, len(boxes))
+     return predicted_orders
+
+
+ def process_page(
+     page,
+     model,
+     with_classes: bool,
+     with_separators: bool,
+     visualize: bool,
+     images: Path | None,
+     output_dir: Path,
+     sort_method: str | None,
+ ) -> PageResult:
+     # Pre-order boxes and classes
+     if sort_method is not None:
+         page = sort_sample(page, sort_ratio=1.0, sort_method=sort_method)
+
+     boxes = page["source_boxes"]
+     classes = page["source_classes"] if with_classes else []
+     separators = page["separators"] if with_separators else []
+     predicted_order = predict(model, boxes, classes, separators)
+
+     result = PageResult(
+         boxes=boxes,
+         classes=classes,
+         separators=separators,
+         predicted_order=predicted_order,
+     )
+
+     # Run evaluation if the ground truth is available
+     if "target_index" in page:
+         result.target_order = [i - 1 for i in page["target_index"]]
+
+     if visualize and images is not None and result.boxes:
+         image_path = (images / page["sample_id"]).with_suffix(".jpg")
+         output_path = output_dir / f"{page['sample_id']}_visualize.jpg"
+         save_visualization(
+             image_path=image_path,
+             boxes=boxes,
+             predicted_order=predicted_order,
+             output_path=output_path,
+         )
+
+         result.visualization = str(output_path)
+     return result
+
+
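`process_page` only relies on a handful of keys from each dataset record. A minimal sketch of such a record, with hypothetical values (the exact field types depend on the extraction scripts in `layout_reader.datasets`):

```python
# Hypothetical record; only the keys accessed by process_page are shown.
page = {
    "sample_id": "page_0001",        # also used to locate page_0001.jpg
    "source_boxes": [[0, 0, 10, 10], [0, 20, 10, 30]],
    "source_classes": ["paragraph", "marginalia"],  # read when --with-classes
    "separators": [],                # read when --with-separators
    "target_index": [1, 2],          # 1-based ground-truth order, optional
}
```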
+ def run(
+     dataset: str,
+     split: str,
+     model: str,
+     output_dir: Path,
+     with_classes: bool = False,
+     with_separators: bool = False,
+     images: Path | None = None,
+     visualize: bool = False,
+     sort_method: str | None = None,
+ ) -> None:
+     logger.info(f"Loading model {model}")
+     model = load_model(model_path=model)
+
+     logger.info(f"Loading dataset {dataset}/{split}.jsonl.gz")
+     dataset = load_dataset_split(dataset, split)
+
+     results = {}
+
+     if visualize and images is None:
+         logger.warning("Skipping visualization, as --images is not defined")
+
+     for page in dataset:
+         page_result = process_page(
+             page=page,
+             model=model,
+             with_classes=with_classes,
+             with_separators=with_separators,
+             visualize=visualize,
+             images=images,
+             output_dir=output_dir,
+             sort_method=sort_method,
+         )
+
+         results[page["sample_id"]] = page_result.to_json()
+
+     output_dir.mkdir(exist_ok=True)
+     (output_dir / "predictions.json").write_text(json.dumps(results, indent=4))
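`run` stores one `PageResult.to_json()` entry per `sample_id`. Under the assumptions of the record sketched above, and a perfect prediction, `out/predictions.json` would look like:

```python
# Illustrative shape of predictions.json (None is serialized as null).
{
    "page_0001": {
        "boxes": [[0, 0, 10, 10], [0, 20, 10, 30]],
        "classes": ["paragraph", "marginalia"],
        "separators": [],
        "predicted_order": [0, 1],
        "target_order": [0, 1],
        "visualization": None,
        "average_relative_distance": 0.0,
    }
}
```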
layout_reader/train/sft.py ADDED
@@ -0,0 +1,69 @@
+ import logging
+ from pathlib import Path
+
+ from transformers import (
+     LayoutLMv3ForTokenClassification,
+ )
+ from transformers.trainer import Trainer
+ from trl import SFTConfig
+
+ from layout_reader.helpers import (
+     MAX_LEN,
+     DataCollator,
+     load_dataset_split,
+     read_yaml,
+     sort_sample,
+ )
+
+ logger = logging.getLogger()
+
+
+ def run_training(config: dict):
+     train_dataset = load_dataset_split(Path(config["dataset_dir"]), "train")
+     dev_dataset = load_dataset_split(Path(config["dataset_dir"]), "dev")
+     if config["sort_ratio"] > 0:
+
+         def apply_sort_fn(sample):
+             return sort_sample(
+                 sample,
+                 sort_ratio=config["sort_ratio"],
+                 sort_method=config["sort_method"],
+             )
+
+         train_dataset = train_dataset.map(apply_sort_fn)
+         dev_dataset = dev_dataset.map(apply_sort_fn)
+
+     logger.info(
+         f"Train dataset size: {len(train_dataset)}, Dev dataset size: {len(dev_dataset)}"
+     )
+
+     model = LayoutLMv3ForTokenClassification.from_pretrained(
+         config["model_dir"], num_labels=MAX_LEN, visual_embed=False
+     )
+
+     data_collator = DataCollator(
+         with_classes=config["with_classes"], with_separators=config["with_separators"]
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=SFTConfig(**config["SFTTrainer"]),
+         train_dataset=train_dataset,
+         eval_dataset=dev_dataset,
+         data_collator=data_collator,
+     )
+     trainer.train()
+
+
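`run_training` receives the YAML file parsed into a dict by `read_yaml` (see the `--config` argument below). A minimal sketch of the expected shape, inferred from the keys accessed above; all values are hypothetical:

```python
# Hypothetical training config; the "SFTTrainer" mapping is unpacked into
# trl's SFTConfig and passed to the Trainer.
config = {
    "dataset_dir": "data/lr_dataset",   # contains train.jsonl.gz and dev.jsonl.gz
    "model_dir": "path/to/layoutlmv3",  # base checkpoint for from_pretrained
    "sort_ratio": 0.5,                  # share of samples to pre-sort; 0 disables
    "sort_method": "sortxy",            # one of the helpers' sort methods
    "with_classes": True,
    "with_separators": False,
    "SFTTrainer": {
        "output_dir": "out/training",
        "num_train_epochs": 3,
    },
}
```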
+ def add_training_command(subcommands) -> None:
+     parser = subcommands.add_parser(
+         "train",
+         description="Train a Layout Reader model on a dataset.",
+         help="Train a Layout Reader model on a dataset.",
+     )
+     parser.add_argument(
+         "--config",
+         type=read_yaml,
+         help="Path to the YAML training configuration file",
+     )
+     parser.set_defaults(func=run_training)
teklia_layout_reader-0.2.1.dist-info/METADATA ADDED
@@ -0,0 +1,62 @@
+ Metadata-Version: 2.4
+ Name: teklia-layout-reader
+ Version: 0.2.1
+ Summary: Scripts for Layout Reader
+ Author-email: Teklia <contact@teklia.com>
+ Maintainer-email: Teklia <contact@teklia.com>
+ License-Expression: MIT
+ Keywords: python
+ Classifier: Programming Language :: Python :: 3 :: Only
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: teklia-toolbox==0.1.11
+ Requires-Dist: datasets==3.2.0
+ Requires-Dist: mdutils==1.6.0
+ Requires-Dist: prettytable==3.13.0
+ Requires-Dist: opencv-python-headless==4.12.0.88
+ Requires-Dist: matplotlib==3.10.6
+ Requires-Dist: deepspeed==0.18.1
+ Requires-Dist: trl==0.23.0
+ Requires-Dist: mpi4py==4.1.1
+ Requires-Dist: wandb==0.23.1
+ Requires-Dist: transformers==4.57.6
+ Requires-Dist: colour==0.1.5
+ Requires-Dist: accelerate==1.12.0
+
+ # Layout Reader
+
+ Scripts for Layout Reader
+
+ ### Development
+
+ For development and test purposes, it may be useful to install the project as an editable package with pip.
+
+ * Use a virtualenv (e.g. with virtualenvwrapper `mkvirtualenv -a . layout-reader`)
+ * Install layout-reader as a package (e.g. `pip install -e .`)
+
+ ### Linter
+
+ Code syntax is analyzed before the code is submitted.\
+ To run the linter tool suite, you may use pre-commit.
+
+ ```shell
+ pip install pre-commit
+ pre-commit run -a
+ ```
+
+ ### Run tests
+
+ Tests are executed with `tox` using [pytest](https://pytest.org).
+
+ ```shell
+ pip install tox
+ tox
+ ```
+
+ To recreate the tox virtual environment (e.g. after a dependency update), you may run `tox -r`.
+
+ Run a single test module: `tox -- <test_path>`\
+ Run a single test: `tox -- <test_path>::<test_function>`
teklia_layout_reader-0.2.1.dist-info/RECORD ADDED
@@ -0,0 +1,22 @@
+ layout_reader/__init__.py,sha256=5doGcBcb3kTQMoDbjiESh1g0CrG_mHTJlGRX4Uo7NRU,156
+ layout_reader/cli.py,sha256=4M4nzV7kZmF-LmUqHie_eeXJIhvHdhgAi-_q9gm7yEo,665
+ layout_reader/helpers.py,sha256=rWJT0glBqOYCj4tvktiPunxoTHiNb5v2oD5MainDrhU,11675
+ layout_reader/inference.py,sha256=ECujeVPfYI6DM0yl2WorzmWwD5bMZTblm4uFg2ggldY,6616
+ layout_reader/datasets/__init__.py,sha256=ppjB6O0uDJLR3FG3aKnZmtZzYLWJC8Aeb3KjnAWawzE,2663
+ layout_reader/datasets/analyze.py,sha256=Lw_Coqpudlz2v4quZPvWr61AEJYTkSYJNUHkzeCxZ38,4700
+ layout_reader/datasets/extract.py,sha256=Sdoe8kMBrjKd2ydGYDySc-6e2w6_0BgFa8sLqTsQNKU,8726
+ layout_reader/datasets/lsd.py,sha256=jVE6ByzR9OPnvLvpYbeF5adDWHEeVLFor8-9o9uVzGE,4211
+ layout_reader/datasets/utils.py,sha256=Nvw3y441WS-fU-9qXCMcM3_UZaEBnBsS19Is0zu6Z_w,3393
+ layout_reader/train/sft.py,sha256=yb2MqEZhQDxkBcUXqvPC-ebjG7BuaLzoT2vEaRo7HlE,1872
+ tests/__init__.py,sha256=Z6xMdy1fM-m3Ncq4LMozIv_LvHsHnUCcFPGe7qaI3h0,78
+ tests/conftest.py,sha256=CX31U3NxXMLQp28DK2uU4GELObTAKdaZvlX9B6JlGrY,500
+ tests/test_analyze.py,sha256=DeOQ2vy0qYwEmIquhLCN_LrOgVwv_sO55apXaLpYGxo,405
+ tests/test_cli.py,sha256=vizFW9k23hKZQSNjXZOXGtwfzS646HK8k4Huag4T1lI,246
+ tests/test_extract.py,sha256=ruT3fedzXCQ1FdW3uYYEZA2VdX3GxK67hzdoLi1u9jA,3251
+ tests/test_helpers.py,sha256=sOX3JWfZE_EL2ZD6qNF9vFbqSm6Ak2158nZcFsdD91o,13896
+ tests/test_predict.py,sha256=DoZZtG8AlJtwxF2f6bz1CPCIrMjLjs4MbHRA8XF4Q5k,1613
+ teklia_layout_reader-0.2.1.dist-info/METADATA,sha256=DU5rASxjSgEjEE0LZBTIOvSxgLjUBZM0owXPyAQ5kr8,1735
+ teklia_layout_reader-0.2.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ teklia_layout_reader-0.2.1.dist-info/entry_points.txt,sha256=9kJyzLJuIhGeoYnwHxYITSNh3--7v7HTMfgwIpG0_fc,64
+ teklia_layout_reader-0.2.1.dist-info/top_level.txt,sha256=G0dTKLPB9MCL_4pPL8Y0aqjT6ujCWb3Y1Si3UutIvCU,20
+ teklia_layout_reader-0.2.1.dist-info/RECORD,,
teklia_layout_reader-0.2.1.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.10.2)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
teklia_layout_reader-0.2.1.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ teklia-layout-reader = layout_reader.cli:main
teklia_layout_reader-0.2.1.dist-info/top_level.txt ADDED
@@ -0,0 +1,2 @@
+ layout_reader
+ tests
tests/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from pathlib import Path
+
+ FIXTURES = Path(__file__).resolve().parent / "data"
tests/conftest.py ADDED
@@ -0,0 +1,19 @@
+ import os
+
+ import pytest
+
+
+ @pytest.fixture(autouse=True)
+ def _setup_environment(responses):
+     """Set up the needed environment variables"""
+
+     # Allow accessing remote API schemas
+     # defaulting to the prod environment
+     schema_url = os.environ.get(
+         "ARKINDEX_API_SCHEMA_URL",
+         "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json",
+     )
+     responses.add_passthru(schema_url)
+
+     # Set schema url in environment
+     os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url
tests/test_analyze.py ADDED
@@ -0,0 +1,14 @@
+ from layout_reader.datasets.analyze import Statistics
+ from tests import FIXTURES
+
+ DATASET = FIXTURES / "lr_dataset" / "train.jsonl.gz"
+ REPORT = FIXTURES / "dataset_report.md"
+
+
+ def test_run(tmp_path) -> None:
+     output_file = tmp_path / "report.md"
+     stats = Statistics(output=str(output_file))
+     stats.run(
+         filenames=[DATASET],
+     )
+     assert output_file.read_text() == REPORT.read_text()
tests/test_cli.py ADDED
@@ -0,0 +1,11 @@
+ import importlib
+
+
+ def test_dummy():
+     assert True
+
+
+ def test_import():
+     """Import our newly created module, through importlib to avoid parsing issues"""
+     cli = importlib.import_module("layout_reader.cli")
+     assert hasattr(cli, "main")
tests/test_extract.py ADDED
@@ -0,0 +1,130 @@
+ import random
+ from pathlib import Path
+
+ import numpy as np
+ import pytest
+
+ from layout_reader.datasets.extract import (
+     Mode,
+     run_dataset_extraction,
+ )
+ from layout_reader.datasets.utils import load_gzip_jsonl
+ from tests import FIXTURES
+
+
+ @pytest.fixture
+ def SPLIT() -> str:
+     return "train"
+
+
+ @pytest.fixture
+ def hf_dataset() -> str:
+     return str(FIXTURES / "hf_dataset")
+
+
+ @pytest.fixture
+ def yolo_dataset() -> str:
+     return str(FIXTURES / "yolo_dataset")
+
+
+ @pytest.fixture
+ def lr_dataset() -> str:
+     return str(FIXTURES / "lr_dataset")
+
+
+ @pytest.mark.parametrize(
+     ("shuffle_rate", "extract_classes", "extract_separators"),
+     [
+         (0.0, False, False),
+         (0.0, True, False),
+         (0.0, False, True),
+         (1.0, True, True),
+         (0.5, False, True),
+         (0.5, True, False),
+     ],
+ )
+ def test_run_dataset_extraction_modes(
+     hf_dataset,
+     yolo_dataset,
+     shuffle_rate,
+     extract_classes,
+     extract_separators,
+     tmp_path,
+ ):
+     output_dir_hf = tmp_path / "output_hf"
+     output_dir_yolo = tmp_path / "output_yolo"
+
+     random.seed(42)
+     np.random.seed(42)
+     run_dataset_extraction(
+         hf_dataset,
+         Mode.HF,
+         shuffle_rate=shuffle_rate,
+         output_dir=output_dir_hf,
+         extract_classes=extract_classes,
+         extract_separators=extract_separators,
+     )
+
+     random.seed(42)
+     np.random.seed(42)
+     run_dataset_extraction(
+         yolo_dataset,
+         Mode.YOLO,
+         shuffle_rate=shuffle_rate,
+         output_dir=output_dir_yolo,
+         extract_classes=extract_classes,
+         extract_separators=extract_separators,
+     )
+     assert (Path(output_dir_hf) / "train.jsonl.gz").exists()
+     assert (Path(output_dir_yolo) / "train.jsonl.gz").exists()
+     assert load_gzip_jsonl(Path(output_dir_hf) / "train.jsonl.gz") == load_gzip_jsonl(
+         Path(output_dir_yolo) / "train.jsonl.gz"
+     )
+
+
+ @pytest.mark.parametrize(
+     ("shuffle_rate", "extract_classes", "extract_separators"),
+     [
+         (0.5, True, True),
+     ],
+ )
+ def test_run_dataset_extraction(
+     hf_dataset,
+     yolo_dataset,
+     lr_dataset,
+     shuffle_rate,
+     extract_classes,
+     extract_separators,
+     tmp_path,
+ ):
+     output_dir_hf = tmp_path / "output_hf"
+     output_dir_yolo = tmp_path / "output_yolo"
+
+     random.seed(42)
+     np.random.seed(42)
+     run_dataset_extraction(
+         hf_dataset,
+         Mode.HF,
+         shuffle_rate=shuffle_rate,
+         output_dir=output_dir_hf,
+         extract_classes=extract_classes,
+         extract_separators=extract_separators,
+     )
+
+     random.seed(42)
+     np.random.seed(42)
+     run_dataset_extraction(
+         yolo_dataset,
+         Mode.YOLO,
+         shuffle_rate=shuffle_rate,
+         output_dir=output_dir_yolo,
+         extract_classes=extract_classes,
+         extract_separators=extract_separators,
+     )
+     assert (Path(output_dir_hf) / "train.jsonl.gz").exists()
+     assert (Path(output_dir_yolo) / "train.jsonl.gz").exists()
+
+     dataset_from_hf = load_gzip_jsonl(Path(output_dir_hf) / "train.jsonl.gz")
+     dataset_from_yolo = load_gzip_jsonl(Path(output_dir_yolo) / "train.jsonl.gz")
+     expected_dataset = load_gzip_jsonl(Path(lr_dataset) / "train.jsonl.gz")
+     assert dataset_from_hf == dataset_from_yolo == expected_dataset