teklia-layout-reader 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- layout_reader/__init__.py +9 -0
- layout_reader/cli.py +27 -0
- layout_reader/datasets/__init__.py +99 -0
- layout_reader/datasets/analyze.py +161 -0
- layout_reader/datasets/extract.py +289 -0
- layout_reader/datasets/lsd.py +133 -0
- layout_reader/datasets/utils.py +128 -0
- layout_reader/helpers.py +358 -0
- layout_reader/inference.py +215 -0
- layout_reader/train/sft.py +69 -0
- teklia_layout_reader-0.2.1.dist-info/METADATA +62 -0
- teklia_layout_reader-0.2.1.dist-info/RECORD +22 -0
- teklia_layout_reader-0.2.1.dist-info/WHEEL +5 -0
- teklia_layout_reader-0.2.1.dist-info/entry_points.txt +2 -0
- teklia_layout_reader-0.2.1.dist-info/top_level.txt +2 -0
- tests/__init__.py +3 -0
- tests/conftest.py +19 -0
- tests/test_analyze.py +14 -0
- tests/test_cli.py +11 -0
- tests/test_extract.py +130 -0
- tests/test_helpers.py +438 -0
- tests/test_predict.py +64 -0
layout_reader/inference.py
ADDED

@@ -0,0 +1,215 @@
+import json
+import logging
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any
+
+from layout_reader.helpers import (
+    boxes_to_inputs,
+    load_dataset_split,
+    load_model,
+    parse_logits,
+    prepare_inputs,
+    save_visualization,
+    sort_sample,
+)
+
+logger = logging.getLogger(__name__)
+DATA_FILES = ["train", "val", "dev", "test"]
+
+
+@dataclass
+class PageResult:
+    boxes: list[Any]  # List of bounding boxes to reorder
+    classes: list[Any]  # List of corresponding classes
+    separators: list[Any]  # List of separators, used as context by the model
+    predicted_order: list[int]  # Detected reading order
+
+    # Optional evaluation outputs
+    target_order: list[int] | None = None  # Target reading order
+
+    # Optional visualization output
+    visualization: str | None = None  # Path to the visualization image
+
+    @property
+    def average_relative_distance(self) -> float:
+        """
+        Compute the Average Relative Distance (ARD) between predicted and ground-truth ordering.
+        """
+        gt_zones = [self.boxes[i] for i in self.target_order]
+        predicted_zones = [self.boxes[i] for i in self.predicted_order]
+
+        ard = 0
+
+        if not gt_zones and not predicted_zones:
+            return 0
+
+        if not gt_zones or not predicted_zones:
+            return float(max(len(gt_zones), len(predicted_zones)))
+
+        for pred_idx, pred_zone in enumerate(predicted_zones):
+            if pred_zone in gt_zones:
+                ard += abs(gt_zones.index(pred_zone) - pred_idx)
+            else:
+                ard += len(predicted_zones)
+
+        return ard / len(predicted_zones)
+
+    def to_json(self) -> dict[str, Any]:
+        data = asdict(self)
+
+        if self.target_order is not None:
+            data["average_relative_distance"] = self.average_relative_distance
+
+        return data
+
+
+def add_predict_command(subcommands) -> None:
+    parser = subcommands.add_parser(
+        "inference",
+        description="Predict and evaluate a model on a Layout Reader dataset.",
+        help="Predict and evaluate a model on a Layout Reader dataset.",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        help="Path to the local LayoutReader dataset directory. The directory must contain .jsonl.gz files.",
+    )
+    parser.add_argument(
+        "--split",
+        choices=DATA_FILES,
+        help="Dataset split to use. Must match the name of a corresponding archive in the --dataset directory (e.g., 'train' -> train.jsonl.gz).",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="Name of the LayoutReader model checkpoint.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        help="Output directory where results will be saved. If --visualize is set, the predicted reading order will also be plotted on each image.",
+        required=True,
+    )
+    parser.add_argument(
+        "--with-classes",
+        action="store_true",
+        help="Whether to use the zone classes for the prediction.",
+    )
+    parser.add_argument(
+        "--with-separators",
+        action="store_true",
+        help="Whether to use the separators for the prediction.",
+    )
+    parser.add_argument(
+        "--sort-method",
+        choices=["sortxy_by_column", "sortxy", "sortyx", "random"],
+        default=None,
+        help="How to pre-order the input zones. By default, no sorting method is applied.",
+    )
+    parser.add_argument(
+        "--images",
+        type=Path,
+        help="Path to the images (required to visualize predictions).",
+        default=None,
+    )
+    parser.add_argument(
+        "--visualize",
+        action="store_true",
+        help="Whether to visualize the predicted reading order. Plots will be saved in --output-dir.",
+    )
+    parser.set_defaults(func=run)
+
+
+def predict(model, boxes, classes, separators) -> list[int]:
+    if not boxes:
+        return []
+    inputs = boxes_to_inputs(boxes, cls=classes, separators=separators)
+    inputs = prepare_inputs(inputs, model)
+    logits = model(**inputs).logits.squeeze(0)
+    predicted_orders = parse_logits(logits, len(boxes))
+    return predicted_orders
+
+
+def process_page(
+    page,
+    model,
+    with_classes: bool,
+    with_separators: bool,
+    visualize: bool,
+    images: Path | None,
+    output_dir: Path,
+    sort_method: str | None,
+) -> PageResult:
+    # Pre-order boxes and classes
+    if sort_method is not None:
+        page = sort_sample(page, sort_ratio=1.0, sort_method=sort_method)
+
+    boxes = page["source_boxes"]
+    classes = page["source_classes"] if with_classes else []
+    separators = page["separators"] if with_separators else []
+    predicted_order = predict(model, boxes, classes, separators)
+
+    result = PageResult(
+        boxes=boxes,
+        classes=classes,
+        separators=separators,
+        predicted_order=predicted_order,
+    )
+
+    # Run evaluation if the ground truth is available
+    if "target_index" in page:
+        result.target_order = [i - 1 for i in page["target_index"]]
+
+    if visualize and images is not None and result.boxes:
+        image_path = (images / page["sample_id"]).with_suffix(".jpg")
+        output_path = output_dir / f"{page['sample_id']}_visualize.jpg"
+        save_visualization(
+            image_path=image_path,
+            boxes=boxes,
+            predicted_order=predicted_order,
+            output_path=output_path,
+        )
+
+        result.visualization = str(output_path)
+    return result
+
+
+def run(
+    dataset: str,
+    split: str,
+    model: str,
+    output_dir: Path,
+    with_classes: bool = False,
+    with_separators: bool = False,
+    images: Path | None = None,
+    visualize: bool = False,
+    sort_method: str | None = None,
+) -> None:
+    logger.info(f"Loading model {model}")
+    model = load_model(model_path=model)
+
+    logger.info(f"Loading dataset {dataset}/{split}.jsonl.gz")
+    dataset = load_dataset_split(dataset, split)
+
+    results = {}
+
+    if visualize and images is None:
+        logger.warning("Skipping visualization, as --images is not defined")
+
+    for page in dataset:
+        page_result = process_page(
+            page=page,
+            model=model,
+            with_classes=with_classes,
+            with_separators=with_separators,
+            visualize=visualize,
+            images=images,
+            output_dir=output_dir,
+            sort_method=sort_method,
+        )
+
+        results[page["sample_id"]] = page_result.to_json()
+
+    output_dir.mkdir(exist_ok=True)
+    (output_dir / "predictions.json").write_text(json.dumps(results, indent=4))
layout_reader/train/sft.py
ADDED

@@ -0,0 +1,69 @@
+import logging
+from pathlib import Path
+
+from transformers import (
+    LayoutLMv3ForTokenClassification,
+)
+from transformers.trainer import Trainer
+from trl import SFTConfig
+
+from layout_reader.helpers import (
+    MAX_LEN,
+    DataCollator,
+    load_dataset_split,
+    read_yaml,
+    sort_sample,
+)
+
+logger = logging.getLogger()
+
+
+def run_training(config: dict):
+    train_dataset = load_dataset_split(Path(config["dataset_dir"]), "train")
+    dev_dataset = load_dataset_split(Path(config["dataset_dir"]), "dev")
+    if config["sort_ratio"] > 0:
+
+        def apply_sort_fn(dataset):
+            return sort_sample(
+                dataset,
+                sort_ratio=config["sort_ratio"],
+                sort_method=config["sort_method"],
+            )
+
+        train_dataset = train_dataset.map(apply_sort_fn)
+        dev_dataset = dev_dataset.map(apply_sort_fn)
+
+    logger.info(
+        f"Train dataset size: {len(train_dataset)}, Dev dataset size: {len(dev_dataset)}"
+    )
+
+    model = LayoutLMv3ForTokenClassification.from_pretrained(
+        config["model_dir"], num_labels=MAX_LEN, visual_embed=False
+    )
+
+    data_collator = DataCollator(
+        with_classes=config["with_classes"], with_separators=config["with_separators"]
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=SFTConfig(**config["SFTTrainer"]),
+        train_dataset=train_dataset,
+        eval_dataset=dev_dataset,
+        data_collator=data_collator,
+    )
+    trainer.train()
+
+
+def add_training_command(subcommands) -> None:
+    parser = subcommands.add_parser(
+        "train",
+        description="Train a Layout Reader model on a dataset.",
+        help="Train a Layout Reader model on a dataset.",
+    )
+    parser.add_argument(
+        "--config",
+        type=read_yaml,
+        help="Path to the YAML training configuration file.",
+    )
+    parser.set_defaults(func=run_training)
teklia_layout_reader-0.2.1.dist-info/METADATA
ADDED

@@ -0,0 +1,62 @@
+Metadata-Version: 2.4
+Name: teklia-layout-reader
+Version: 0.2.1
+Summary: Scripts for Layout Reader
+Author-email: Teklia <contact@teklia.com>
+Maintainer-email: Teklia <contact@teklia.com>
+License-Expression: MIT
+Keywords: python
+Classifier: Programming Language :: Python :: 3 :: Only
+Classifier: Programming Language :: Python :: 3.10
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+Requires-Dist: teklia-toolbox==0.1.11
+Requires-Dist: datasets==3.2.0
+Requires-Dist: mdutils==1.6.0
+Requires-Dist: prettytable==3.13.0
+Requires-Dist: opencv-python-headless==4.12.0.88
+Requires-Dist: matplotlib==3.10.6
+Requires-Dist: deepspeed==0.18.1
+Requires-Dist: trl==0.23.0
+Requires-Dist: mpi4py==4.1.1
+Requires-Dist: wandb==0.23.1
+Requires-Dist: transformers==4.57.6
+Requires-Dist: colour==0.1.5
+Requires-Dist: accelerate==1.12.0
+
+# Layout Reader
+
+Scripts for Layout Reader
+
+### Development
+
+For development and testing purposes, it may be useful to install the project as an editable package with pip.
+
+* Use a virtualenv (e.g. with virtualenvwrapper `mkvirtualenv -a . layout-reader`)
+* Install layout-reader as a package (e.g. `pip install -e .`)
+
+### Linter
+
+Code syntax is analyzed before the code is submitted.\
+To run the linter tool suite, you may use pre-commit.
+
+```shell
+pip install pre-commit
+pre-commit run -a
+```
+
+### Run tests
+
+Tests are executed with `tox` using [pytest](https://pytest.org).
+
+```shell
+pip install tox
+tox
+```
+
+To recreate the tox virtual environment (e.g. after a dependency update), you may run `tox -r`.
+
+Run a single test module: `tox -- <test_path>`
+Run a single test: `tox -- <test_path>::<test_function>`
teklia_layout_reader-0.2.1.dist-info/RECORD
ADDED

@@ -0,0 +1,22 @@
+layout_reader/__init__.py,sha256=5doGcBcb3kTQMoDbjiESh1g0CrG_mHTJlGRX4Uo7NRU,156
+layout_reader/cli.py,sha256=4M4nzV7kZmF-LmUqHie_eeXJIhvHdhgAi-_q9gm7yEo,665
+layout_reader/helpers.py,sha256=rWJT0glBqOYCj4tvktiPunxoTHiNb5v2oD5MainDrhU,11675
+layout_reader/inference.py,sha256=ECujeVPfYI6DM0yl2WorzmWwD5bMZTblm4uFg2ggldY,6616
+layout_reader/datasets/__init__.py,sha256=ppjB6O0uDJLR3FG3aKnZmtZzYLWJC8Aeb3KjnAWawzE,2663
+layout_reader/datasets/analyze.py,sha256=Lw_Coqpudlz2v4quZPvWr61AEJYTkSYJNUHkzeCxZ38,4700
+layout_reader/datasets/extract.py,sha256=Sdoe8kMBrjKd2ydGYDySc-6e2w6_0BgFa8sLqTsQNKU,8726
+layout_reader/datasets/lsd.py,sha256=jVE6ByzR9OPnvLvpYbeF5adDWHEeVLFor8-9o9uVzGE,4211
+layout_reader/datasets/utils.py,sha256=Nvw3y441WS-fU-9qXCMcM3_UZaEBnBsS19Is0zu6Z_w,3393
+layout_reader/train/sft.py,sha256=yb2MqEZhQDxkBcUXqvPC-ebjG7BuaLzoT2vEaRo7HlE,1872
+tests/__init__.py,sha256=Z6xMdy1fM-m3Ncq4LMozIv_LvHsHnUCcFPGe7qaI3h0,78
+tests/conftest.py,sha256=CX31U3NxXMLQp28DK2uU4GELObTAKdaZvlX9B6JlGrY,500
+tests/test_analyze.py,sha256=DeOQ2vy0qYwEmIquhLCN_LrOgVwv_sO55apXaLpYGxo,405
+tests/test_cli.py,sha256=vizFW9k23hKZQSNjXZOXGtwfzS646HK8k4Huag4T1lI,246
+tests/test_extract.py,sha256=ruT3fedzXCQ1FdW3uYYEZA2VdX3GxK67hzdoLi1u9jA,3251
+tests/test_helpers.py,sha256=sOX3JWfZE_EL2ZD6qNF9vFbqSm6Ak2158nZcFsdD91o,13896
+tests/test_predict.py,sha256=DoZZtG8AlJtwxF2f6bz1CPCIrMjLjs4MbHRA8XF4Q5k,1613
+teklia_layout_reader-0.2.1.dist-info/METADATA,sha256=DU5rASxjSgEjEE0LZBTIOvSxgLjUBZM0owXPyAQ5kr8,1735
+teklia_layout_reader-0.2.1.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+teklia_layout_reader-0.2.1.dist-info/entry_points.txt,sha256=9kJyzLJuIhGeoYnwHxYITSNh3--7v7HTMfgwIpG0_fc,64
+teklia_layout_reader-0.2.1.dist-info/top_level.txt,sha256=G0dTKLPB9MCL_4pPL8Y0aqjT6ujCWb3Y1Si3UutIvCU,20
+teklia_layout_reader-0.2.1.dist-info/RECORD,,
tests/__init__.py
ADDED
tests/conftest.py
ADDED
@@ -0,0 +1,19 @@
+import os
+
+import pytest
+
+
+@pytest.fixture(autouse=True)
+def _setup_environment(responses):
+    """Setup needed environment variables"""
+
+    # Allow accessing remote API schemas
+    # defaulting to the prod environment
+    schema_url = os.environ.get(
+        "ARKINDEX_API_SCHEMA_URL",
+        "https://arkindex.teklia.com/api/v1/openapi/?format=openapi-json",
+    )
+    responses.add_passthru(schema_url)
+
+    # Set schema url in environment
+    os.environ["ARKINDEX_API_SCHEMA_URL"] = schema_url
tests/test_analyze.py
ADDED
@@ -0,0 +1,14 @@
+from layout_reader.datasets.analyze import Statistics
+from tests import FIXTURES
+
+DATASET = FIXTURES / "lr_dataset" / "train.jsonl.gz"
+REPORT = FIXTURES / "dataset_report.md"
+
+
+def test_run(tmp_path) -> None:
+    output_file = tmp_path / "report.md"
+    stats = Statistics(output=str(output_file))
+    stats.run(
+        filenames=[DATASET],
+    )
+    assert output_file.read_text() == REPORT.read_text()
tests/test_cli.py
ADDED
tests/test_extract.py
ADDED
@@ -0,0 +1,130 @@
+import random
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+from layout_reader.datasets.extract import (
+    Mode,
+    run_dataset_extraction,
+)
+from layout_reader.datasets.utils import load_gzip_jsonl
+from tests import FIXTURES
+
+
+@pytest.fixture
+def SPLIT() -> str:
+    return "train"
+
+
+@pytest.fixture
+def hf_dataset() -> str:
+    return str(FIXTURES / "hf_dataset")
+
+
+@pytest.fixture
+def yolo_dataset() -> str:
+    return str(FIXTURES / "yolo_dataset")
+
+
+@pytest.fixture
+def lr_dataset() -> str:
+    return str(FIXTURES / "lr_dataset")
+
+
+@pytest.mark.parametrize(
+    ("shuffle_rate", "extract_classes", "extract_separators"),
+    [
+        (0.0, False, False),
+        (0.0, True, False),
+        (0.0, False, True),
+        (1.0, True, True),
+        (0.5, False, True),
+        (0.5, True, False),
+    ],
+)
+def test_run_dataset_extraction_modes(
+    hf_dataset,
+    yolo_dataset,
+    shuffle_rate,
+    extract_classes,
+    extract_separators,
+    tmp_path,
+):
+    output_dir_hf = tmp_path / "output_hf"
+    output_dir_yolo = tmp_path / "output_yolo"
+
+    random.seed(42)
+    np.random.seed(42)
+    run_dataset_extraction(
+        hf_dataset,
+        Mode.HF,
+        shuffle_rate=shuffle_rate,
+        output_dir=output_dir_hf,
+        extract_classes=extract_classes,
+        extract_separators=extract_separators,
+    )
+
+    random.seed(42)
+    np.random.seed(42)
+    run_dataset_extraction(
+        yolo_dataset,
+        Mode.YOLO,
+        shuffle_rate=shuffle_rate,
+        output_dir=output_dir_yolo,
+        extract_classes=extract_classes,
+        extract_separators=extract_separators,
+    )
+    assert (Path(output_dir_hf) / "train.jsonl.gz").exists()
+    assert (Path(output_dir_yolo) / "train.jsonl.gz").exists()
+    assert load_gzip_jsonl(Path(output_dir_hf) / "train.jsonl.gz") == load_gzip_jsonl(
+        Path(output_dir_yolo) / "train.jsonl.gz"
+    )
+
+
+@pytest.mark.parametrize(
+    ("shuffle_rate", "extract_classes", "extract_separators"),
+    [
+        (0.5, True, True),
+    ],
+)
+def test_run_dataset_extraction(
+    hf_dataset,
+    yolo_dataset,
+    lr_dataset,
+    shuffle_rate,
+    extract_classes,
+    extract_separators,
+    tmp_path,
+):
+    output_dir_hf = tmp_path / "output_hf"
+    output_dir_yolo = tmp_path / "output_yolo"
+
+    random.seed(42)
+    np.random.seed(42)
+    run_dataset_extraction(
+        hf_dataset,
+        Mode.HF,
+        shuffle_rate=shuffle_rate,
+        output_dir=output_dir_hf,
+        extract_classes=extract_classes,
+        extract_separators=extract_separators,
+    )
+
+    random.seed(42)
+    np.random.seed(42)
+    run_dataset_extraction(
+        yolo_dataset,
+        Mode.YOLO,
+        shuffle_rate=shuffle_rate,
+        output_dir=output_dir_yolo,
+        extract_classes=extract_classes,
+        extract_separators=extract_separators,
+    )
+    assert (Path(output_dir_hf) / "train.jsonl.gz").exists()
+    assert (Path(output_dir_yolo) / "train.jsonl.gz").exists()
+
+    dataset_from_hf = load_gzip_jsonl(Path(output_dir_hf) / "train.jsonl.gz")
+    dataset_from_yolo = load_gzip_jsonl(Path(output_dir_yolo) / "train.jsonl.gz")
+    expected_dataset = load_gzip_jsonl(Path(lr_dataset) / "train.jsonl.gz")
+    assert dataset_from_hf == dataset_from_yolo == expected_dataset