teklia_layout_reader-0.2.1-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
tests/test_helpers.py ADDED
@@ -0,0 +1,438 @@
+ import pytest
+ import torch
+
+ from layout_reader.helpers import (
+     CLS_TOKEN_ID,
+     EOS_TOKEN_ID,
+     IGNORE_LABEL_ID,
+     PAD_TOKEN_ID,
+     DataCollator,
+     boxes_to_inputs,
+     load_dataset_split,
+     parse_logits,
+     sort_sample,
+ )
+ from tests import FIXTURES
+
+
+ @pytest.fixture
+ def dataset() -> str:
+     return str(FIXTURES / "lr_dataset")
+
+
+ @pytest.mark.parametrize(
+     (
+         "features",
+         "with_classes",
+         "with_separators",
+         "expected_boxes",
+         "expected_input_ids",
+         "expected_attention_masks",
+         "expected_labels",
+     ),
+     [
+         (
+             [
+                 {
+                     "source_boxes": [[1, 1, 2, 2], [2, 2, 3, 3]],
+                     "separators": [[1, 1, 1, 6]],
+                     "target_index": [1, 2],
+                     "source_classes": [10, 11],
+                 }
+             ],
+             True,
+             True,
+             torch.tensor(
+                 [[[0, 0, 0, 0], [1, 1, 2, 2], [2, 2, 3, 3], [1, 1, 1, 6], [0, 0, 0, 0]]]
+             ),
+             torch.tensor([[CLS_TOKEN_ID, 10, 11, PAD_TOKEN_ID, EOS_TOKEN_ID]]),
+             torch.tensor([[1, 1, 1, 1, 1]]),
+             torch.tensor([[IGNORE_LABEL_ID, 0, 1, IGNORE_LABEL_ID, IGNORE_LABEL_ID]]),
+         ),
+         (
+             [
+                 {
+                     "source_boxes": [[1, 1, 2, 2], [2, 2, 3, 3]],
+                     "separators": [[1, 1, 1, 6]],
+                     "target_index": [1, 2],
+                     "source_classes": [10, 11],
+                 }
+             ],
+             False,
+             True,
+             torch.tensor(
+                 [[[0, 0, 0, 0], [1, 1, 2, 2], [2, 2, 3, 3], [1, 1, 1, 6], [0, 0, 0, 0]]]
+             ),
+             torch.tensor(
+                 [[CLS_TOKEN_ID, PAD_TOKEN_ID, PAD_TOKEN_ID, PAD_TOKEN_ID, EOS_TOKEN_ID]]
+             ),
+             torch.tensor([[1, 1, 1, 1, 1]]),
+             torch.tensor([[IGNORE_LABEL_ID, 0, 1, IGNORE_LABEL_ID, IGNORE_LABEL_ID]]),
+         ),
+         (
+             [
+                 {
+                     "source_boxes": [[1, 1, 2, 2], [2, 2, 3, 3]],
+                     "separators": [[1, 1, 1, 6]],
+                     "target_index": [1, 2],
+                     "source_classes": [10, 11],
+                 }
+             ],
+             False,
+             False,
+             torch.tensor([[[0, 0, 0, 0], [1, 1, 2, 2], [2, 2, 3, 3], [0, 0, 0, 0]]]),
+             torch.tensor([[CLS_TOKEN_ID, PAD_TOKEN_ID, PAD_TOKEN_ID, EOS_TOKEN_ID]]),
+             torch.tensor([[1, 1, 1, 1]]),
+             torch.tensor([[IGNORE_LABEL_ID, 0, 1, IGNORE_LABEL_ID]]),
+         ),
+         (
+             [
+                 {
+                     "source_boxes": [[1, 1, 2, 2], [2, 2, 3, 3]],
+                     "separators": [[1, 1, 1, 6]],
+                     "target_index": [1, 2],
+                     "source_classes": [10, 11],
+                 }
+             ],
+             True,
+             False,
+             torch.tensor([[[0, 0, 0, 0], [1, 1, 2, 2], [2, 2, 3, 3], [0, 0, 0, 0]]]),
+             torch.tensor([[CLS_TOKEN_ID, 10, 11, EOS_TOKEN_ID]]),
+             torch.tensor([[1, 1, 1, 1]]),
+             torch.tensor([[IGNORE_LABEL_ID, 0, 1, IGNORE_LABEL_ID]]),
+         ),
+     ],
+ )
+ def test_data_collator(
+     features,
+     with_classes,
+     with_separators,
+     expected_labels,
+     expected_boxes,
+     expected_input_ids,
+     expected_attention_masks,
+ ):
+     collator = DataCollator(with_classes=with_classes, with_separators=with_separators)
+     output = collator(features)
+
+     assert set(output.keys()) == {
+         "bbox",
+         "attention_mask",
+         "labels",
+         "input_ids",
+     }
+
+     assert torch.equal(output["labels"], expected_labels)
+     assert torch.equal(output["bbox"], expected_boxes)
+     assert torch.equal(output["input_ids"], expected_input_ids)
+     assert torch.equal(output["attention_mask"], expected_attention_masks)
+
+
+ @pytest.mark.parametrize(
+     ("boxes", "classes", "separators", "expected_output"),
+     [
+         (
+             [[1, 1, 2, 2], [2, 2, 3, 3]],
+             [10, 11],
+             [[1, 1, 1, 6]],
+             {
+                 "bbox": torch.tensor(
+                     [
+                         [
+                             [0, 0, 0, 0],
+                             [1, 1, 2, 2],
+                             [2, 2, 3, 3],
+                             [1, 1, 1, 6],
+                             [0, 0, 0, 0],
+                         ]
+                     ]
+                 ),
+                 "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
+                 "input_ids": torch.tensor(
+                     [[CLS_TOKEN_ID, 10, 11, PAD_TOKEN_ID, EOS_TOKEN_ID]]
+                 ),
+             },
+         ),
+         (
+             [[1, 1, 2, 2], [2, 2, 3, 3]],
+             [],
+             [[1, 1, 1, 6]],
+             {
+                 "bbox": torch.tensor(
+                     [
+                         [
+                             [0, 0, 0, 0],
+                             [1, 1, 2, 2],
+                             [2, 2, 3, 3],
+                             [1, 1, 1, 6],
+                             [0, 0, 0, 0],
+                         ]
+                     ]
+                 ),
+                 "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
+                 "input_ids": torch.tensor(
+                     [
+                         [
+                             CLS_TOKEN_ID,
+                             PAD_TOKEN_ID,
+                             PAD_TOKEN_ID,
+                             PAD_TOKEN_ID,
+                             EOS_TOKEN_ID,
+                         ]
+                     ]
+                 ),
+             },
+         ),
+         (
+             [[1, 1, 2, 2], [2, 2, 3, 3]],
+             [],
+             [],
+             {
+                 "bbox": torch.tensor(
+                     [[[0, 0, 0, 0], [1, 1, 2, 2], [2, 2, 3, 3], [0, 0, 0, 0]]]
+                 ),
+                 "attention_mask": torch.tensor([[1, 1, 1, 1]]),
+                 "input_ids": torch.tensor(
+                     [[CLS_TOKEN_ID, PAD_TOKEN_ID, PAD_TOKEN_ID, EOS_TOKEN_ID]]
+                 ),
+             },
+         ),
+     ],
+ )
+ def test_boxes_to_input(boxes, classes, separators, expected_output):
+     output = boxes_to_inputs(boxes, classes, separators)
+
+     assert set(output.keys()) == {
+         "bbox",
+         "attention_mask",
+         "input_ids",
+     }
+     assert torch.equal(output["attention_mask"], expected_output["attention_mask"])
+     assert torch.equal(output["bbox"], expected_output["bbox"])
+     assert torch.equal(output["input_ids"], expected_output["input_ids"])
+
+
+ @pytest.mark.parametrize(
+     ("logits", "length", "expected_order"),
+     [
+         (
+             torch.tensor(
+                 [
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # CLS token
+                     [0.1, 0.8, 0.05, 0.05, 0.0],  # Element 0 - position 0
+                     [0.7, 0.1, 0.05, 0.1, 0.05],  # Element 1 - ignored
+                     [0.05, 0.25, 0.1, 0.55, 0.05],  # Element 2 - ignored
+                     [0.05, 0.15, 0.7, 0.05, 0.05],  # Element 3 - ignored
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # EOS token
+                 ]
+             ),
+             1,
+             [0],
+         ),
+         (
+             torch.tensor(
+                 [
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # CLS token
+                     [0.1, 0.8, 0.05, 0.05, 0.0],  # Element 0 - position 1
+                     [0.7, 0.1, 0.05, 0.1, 0.05],  # Element 1 - position 0
+                     [0.05, 0.25, 0.1, 0.55, 0.05],  # Element 2 - position 2
+                     [0.05, 0.15, 0.7, 0.05, 0.05],  # Element 3 - ignored
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # EOS token
+                 ]
+             ),
+             3,
+             [1, 0, 2],
+         ),
+         (
+             torch.tensor(
+                 [
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # CLS token
+                     [0.1, 0.8, 0.05, 0.05, 0.0],  # Element 0 - position 1
+                     [0.7, 0.1, 0.05, 0.1, 0.05],  # Element 1 - position 0
+                     [0.05, 0.25, 0.1, 0.55, 0.05],  # Element 2 - position 3
+                     [0.05, 0.15, 0.7, 0.05, 0.05],  # Element 3 - position 2
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # EOS token
+                 ]
+             ),
+             4,
+             [1, 0, 3, 2],
+         ),
+         (
+             torch.tensor(
+                 [  # Conflict element 0 and 1 (not equal)
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # CLS token
+                     [
+                         0.6,
+                         0.2,
+                         0.05,
+                         0.1,
+                         0.05,
+                     ],  # Element 0 - position 0 preferred (conflict -> position 1)
+                     [
+                         0.7,
+                         0.1,
+                         0.05,
+                         0.1,
+                         0.05,
+                     ],  # Element 1 - position 0 preferred (conflict -> position 0)
+                     [0.05, 0.25, 0.1, 0.55, 0.05],  # Element 2 - position 3 preferred
+                     [0.05, 0.15, 0.7, 0.05, 0.05],  # Element 3 - position 2 preferred
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # EOS token
+                 ]
+             ),
+             4,
+             [1, 0, 3, 2],
+         ),
+         (
+             torch.tensor(
+                 [  # Conflict element 0 and 1 (equal)
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # CLS token
+                     [
+                         0.7,
+                         0.1,
+                         0.05,
+                         0.1,
+                         0.05,
+                     ],  # Element 0 - position 0 preferred (conflict -> position 1)
+                     [
+                         0.7,
+                         0.1,
+                         0.05,
+                         0.1,
+                         0.05,
+                     ],  # Element 1 - position 0 preferred (conflict -> position 0)
+                     [0.05, 0.25, 0.1, 0.55, 0.05],  # Element 2 - position 3 preferred
+                     [0.05, 0.15, 0.7, 0.05, 0.05],  # Element 3 - position 2 preferred
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # EOS token
+                 ]
+             ),
+             4,
+             [0, 1, 3, 2],
+         ),
+         (
+             torch.tensor(
+                 [  # Cascade conflicts
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # CLS token
+                     [
+                         0.6,
+                         0.1,
+                         0.05,
+                         0.1,
+                         0.05,
+                     ],  # Element 0 - position 0 preferred (conflict elements #1 #2 #3 -> position 3)
+                     [
+                         0.7,
+                         0.1,
+                         0.05,
+                         0.1,
+                         0.05,
+                     ],  # Element 1 - position 0 preferred (conflict element #0 -> position 0)
+                     [
+                         0.05,
+                         0.25,
+                         0.55,
+                         0.1,
+                         0.05,
+                     ],  # Element 2 - position 2 preferred (conflict element #0 -> position 2)
+                     [
+                         0.0,
+                         0.7,
+                         0.3,
+                         0.05,
+                         0.05,
+                     ],  # Element 3 - position 1 preferred (conflict element #0 -> position 1)
+                     [0.2, 0.2, 0.2, 0.2, 0.2],  # EOS token
+                 ]
+             ),
+             4,
+             [3, 0, 2, 1],
+         ),
+     ],
+ )
+ def test_parse_logits(logits, length, expected_order):
+     output = parse_logits(logits, length=length)
+     assert output == expected_order
+
+
+ def test_load_sort_dataset(dataset):
+     train_dataset = load_dataset_split(dataset, "train")
+     print(train_dataset)
+
+     assert len(train_dataset) == 2
+     assert sorted(train_dataset.column_names) == [
+         "sample_id",
+         "separators",
+         "source_boxes",
+         "source_classes",
+         "target_boxes",
+         "target_classes",
+         "target_index",
+     ]
+
+     sample = train_dataset[0]
+     assert sample["sample_id"] == "84b4fb2c-d62a-4e50-96f8-b0bb04410182"
+     assert sample["separators"][0] == [674, 619, 877, 620]
+     assert sample["source_boxes"][0] == [30, 141, 182, 175]
+     assert sample["target_boxes"][0] == [178, 59, 803, 87]
+     assert sample["source_classes"][0] == 11
+     assert sample["target_classes"][0] == 10
+     assert sample["target_index"][0] == 42
+
+
+ @pytest.mark.parametrize(
+     ("sort_ratio", "sort_method", "expected_sorted_boxes"),
+     [
+         (
+             1,
+             "sortxy",
+             [
+                 [0, 150, 100, 300],
+                 [1, 0, 100, 100],
+                 [200, 255, 400, 400],
+                 [202, 0, 400, 250],
+             ],
+         ),
+         (
+             1,
+             "sortyx",
+             [
+                 [1, 0, 100, 100],
+                 [202, 0, 400, 250],
+                 [0, 150, 100, 300],
+                 [200, 255, 400, 400],
+             ],
+         ),
+         (
+             1,
+             "sortxy_by_column",
+             [
+                 [1, 0, 100, 100],
+                 [0, 150, 100, 300],
+                 [202, 0, 400, 250],
+                 [200, 255, 400, 400],
+             ],
+         ),
+     ],
+ )
+ def test_sort_dataset(sort_ratio, sort_method, expected_sorted_boxes):
+     sample = {
+         "target_classes": [
+             0,
+             0,
+             0,
+             0,
+         ],
+         "target_boxes": [
+             [1, 0, 100, 100],
+             [0, 150, 100, 300],
+             [202, 0, 400, 250],
+             [200, 255, 400, 400],
+         ],
+         "target_index": [1, 2, 3, 4],
+     }
+     sorted_boxes = sort_sample(
+         sample,
+         sort_ratio=sort_ratio,
+         sort_method=sort_method,
+     )["source_boxes"]
+     assert sorted_boxes == expected_sorted_boxes
tests/test_predict.py ADDED
@@ -0,0 +1,64 @@
+ import json
+
+ import pytest
+
+ from layout_reader.helpers import load_model
+ from layout_reader.inference import predict, run
+ from tests import FIXTURES
+
+
+ @pytest.fixture
+ def model() -> str:
+     return str(FIXTURES / "model")
+
+
+ @pytest.fixture
+ def lr_dataset() -> str:
+     return str(FIXTURES / "lr_dataset")
+
+
+ @pytest.fixture
+ def split() -> str:
+     return "train"
+
+
+ @pytest.fixture
+ def images() -> str:
+     return str(FIXTURES / "lr_dataset")
+
+
+ @pytest.fixture
+ def expected_predictions() -> dict:
+     return json.loads((FIXTURES / "predictions.json").read_text())
+
+
+ @pytest.mark.parametrize(
+     ("boxes", "classes", "separators", "expected_order"),
+     [
+         ([], [], [], []),
+         ([[1, 1, 2, 2]], [], [], [0]),
+         ([[1, 1, 2, 2], [2, 2, 3, 3]], [10, 11], [[1, 1, 1, 6]], [0, 1]),
+         ([[1, 1, 2, 2], [2, 2, 3, 3]], [], [[1, 1, 1, 6]], [0, 1]),
+         ([[1, 1, 2, 2], [2, 2, 3, 3]], [10, 11], [], [0, 1]),
+         ([[1, 1, 2, 2], [2, 2, 3, 3]], [], [], [0, 1]),
+     ],
+ )
+ def test_predict(model, boxes, classes, separators, expected_order):
+     model = load_model(model)
+     predicted_order = predict(model, boxes, classes, separators)
+     assert predicted_order == expected_order
+
+
+ def test_run_inference(lr_dataset, split, model, tmp_path, expected_predictions):
+     output_dir = tmp_path / "output"
+
+     run(
+         dataset=lr_dataset,
+         split=split,
+         model=model,
+         output_dir=output_dir,
+     )
+
+     assert (output_dir / "predictions.json").exists()
+     predictions = json.loads((output_dir / "predictions.json").read_text())
+     assert predictions == expected_predictions