magic-pdf 0.5.13__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magic_pdf/cli/magicpdf.py +18 -7
- magic_pdf/libs/config_reader.py +10 -0
- magic_pdf/libs/version.py +1 -1
- magic_pdf/model/__init__.py +1 -0
- magic_pdf/model/doc_analyze_by_custom_model.py +38 -15
- magic_pdf/model/model_list.py +1 -0
- magic_pdf/model/pdf_extract_kit.py +196 -0
- magic_pdf/model/pek_sub_modules/__init__.py +0 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py +0 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py +179 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py +671 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py +476 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py +7 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py +2 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py +171 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py +124 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py +136 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py +284 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py +213 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py +7 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +24 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +60 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +1282 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +32 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +34 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +150 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py +163 -0
- magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py +1236 -0
- magic_pdf/model/pek_sub_modules/post_process.py +36 -0
- magic_pdf/model/pek_sub_modules/self_modify.py +260 -0
- magic_pdf/model/pp_structure_v2.py +7 -0
- magic_pdf/pipe/AbsPipe.py +8 -14
- magic_pdf/pipe/OCRPipe.py +12 -8
- magic_pdf/pipe/TXTPipe.py +12 -8
- magic_pdf/pipe/UNIPipe.py +9 -7
- magic_pdf/resources/model_config/UniMERNet/demo.yaml +46 -0
- magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +351 -0
- magic_pdf/resources/model_config/model_configs.yaml +9 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/METADATA +18 -8
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/RECORD +44 -18
- magic_pdf/model/360_layout_analysis.py +0 -8
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/LICENSE.md +0 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/WHEEL +0 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/entry_points.txt +0 -0
- {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,124 @@
|
|
1
|
+
import torch
|
2
|
+
from dataclasses import dataclass
|
3
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
4
|
+
|
5
|
+
from transformers import BatchEncoding, PreTrainedTokenizerBase
|
6
|
+
from transformers.data.data_collator import (
|
7
|
+
DataCollatorMixin,
|
8
|
+
_torch_collate_batch,
|
9
|
+
)
|
10
|
+
from transformers.file_utils import PaddingStrategy
|
11
|
+
|
12
|
+
from typing import NewType
|
13
|
+
# Opaque alias for one raw input example; it is not referenced anywhere else in this module.
InputDataClass = NewType(name="InputDataClass", tp=Any)
|
14
|
+
|
15
|
+
def pre_calc_rel_mat(segment_ids):
    """Build a pairwise "same segment" mask for every sequence in a batch.

    Args:
        segment_ids: 2-D tensor indexed as ``segment_ids[sample, position]``.

    Returns:
        A ``torch.bool`` tensor of shape ``(batch, seq_len, seq_len)`` on the same device as
        ``segment_ids``. Entry ``[i, j, k]`` is True iff positions ``j`` and ``k`` of sample ``i``
        carry the same segment id.
    """
    # Broadcasting (B, L, 1) against (B, 1, L) replaces the former per-sample / per-position
    # Python loop with one vectorized comparison that yields exactly the same mask.
    return segment_ids[:, :, None] == segment_ids[:, None, :]
|
23
|
+
|
24
|
+
@dataclass
class DataCollatorForKeyValueExtraction(DataCollatorMixin):
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    If the features contain an ``"images"`` entry, the images are stacked into ``batch["images"]``. The
    attention mask is extended with ones for the visual tokens: one per 16x16 patch of a square image, plus one.
    When labels are present, they are extended with ``label_pad_token_id`` for those same positions.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding if
              only a single sequence is provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.

            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.0 (Volta and newer).
        label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
            The id used when padding the labels and as the label of every visual token (-100 is automatically
            ignored by PyTorch loss functions).
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100

    def __call__(self, features):
        """Collate a list of per-example feature dicts into one batch.

        Args:
            features: list of dicts. Text fields are padded by ``tokenizer.pad``. The optional keys
                ``"label"``/``"labels"``, ``"bbox"``, ``"position_ids"``, ``"segment_ids"`` and ``"images"``
                receive the extra handling below. NOTE: ``"images"`` is popped from every input dict.

        Returns:
            Without labels: the ``tokenizer.pad(..., return_tensors="pt")`` result, plus the image entries
            when images are present.
            With labels: a dict whose list-valued entries are converted to ``int64`` tensors. In that dict,
            ``segment_ids`` (if present) is replaced by a ``valid_span`` mask.
        """
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

        images = None
        # Number of visual tokens appended after the text tokens. It stays 0 without images so the
        # segment-id padding below also works for text-only batches (it used to raise NameError there).
        image_len = 0
        if "images" in features[0]:
            # Popped so that `tokenizer.pad` only sees text fields.
            images = torch.stack([torch.tensor(d.pop("images")) for d in features])
            # One token per 16x16 patch plus one extra token; only the last dim is read, i.e. square images
            # are assumed.
            patches_per_side = images.shape[-1] // 16
            image_len = patches_per_side * patches_per_side + 1

        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            # Conversion to tensors will fail if we have labels as they are not of the same length yet.
            return_tensors="pt" if labels is None else None,
        )

        if images is not None:
            batch["images"] = images
            batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v
                     for k, v in batch.items()}
            visual_attention_mask = torch.ones((len(batch['input_ids']), image_len), dtype=torch.long)
            batch["attention_mask"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1)

        if labels is None:
            return batch

        has_bbox_input = "bbox" in features[0]
        has_position_input = "position_ids" in features[0]
        # Position ids are padded with the tokenizer's pad id (presumably the model's position padding index).
        padding_idx = self.tokenizer.pad_token_id
        sequence_length = torch.tensor(batch["input_ids"]).shape[1]
        padding_side = self.tokenizer.padding_side
        if padding_side == "right":
            batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
            if has_bbox_input:
                batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
            if has_position_input:
                batch["position_ids"] = [position_id + [padding_idx] * (sequence_length - len(position_id))
                                         for position_id in batch["position_ids"]]
        else:
            batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
            if has_bbox_input:
                batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]
            if has_position_input:
                batch["position_ids"] = [[padding_idx] * (sequence_length - len(position_id)) + position_id
                                         for position_id in batch["position_ids"]]

        if 'segment_ids' in batch:
            assert 'position_ids' in batch
            # Text padding gets a fresh segment id (last + 1); visual tokens get another (last + 2).
            # NOTE(review): this padding is always appended on the right, even when padding_side == "left".
            for i, segment_ids in enumerate(batch['segment_ids']):
                last_segment = segment_ids[-1]
                batch['segment_ids'][i] = (segment_ids
                                           + [last_segment + 1] * (sequence_length - len(segment_ids))
                                           + [last_segment + 2] * image_len)

        batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}

        if 'segment_ids' in batch:
            batch['valid_span'] = pre_calc_rel_mat(segment_ids=batch['segment_ids'])
            del batch['segment_ids']

        if images is not None:
            # Visual tokens carry no token-classification target: label them like padding.
            visual_labels = torch.ones((len(batch['input_ids']), image_len), dtype=torch.long) * self.label_pad_token_id
            batch["labels"] = torch.cat([batch['labels'], visual_labels], dim=1)

        return batch
|
@@ -0,0 +1,136 @@
|
|
1
|
+
# coding=utf-8
|
2
|
+
'''
|
3
|
+
Reference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py
|
4
|
+
'''
|
5
|
+
import json
|
6
|
+
import os
|
7
|
+
|
8
|
+
import datasets
|
9
|
+
|
10
|
+
from .image_utils import load_image, normalize_bbox
|
11
|
+
|
12
|
+
|
13
|
+
# Module logger obtained through the `datasets` logging helpers.
logger = datasets.logging.get_logger(__name__)

# BibTeX entry for the FUNSD paper. The backslash right after the opening quotes
# keeps the string from starting with a newline.
_CITATION = """\
@article{Jaume2019FUNSDAD,
title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
year={2019},
volume={2},
pages={1-6}
}
"""

# Short dataset description: the project homepage followed by a newline.
_DESCRIPTION = "https://guillaumejaume.github.io/FUNSD/\n"
|
30
|
+
|
31
|
+
|
32
|
+
class FunsdConfig(datasets.BuilderConfig):
    """BuilderConfig for the FUNSD dataset; it adds no options beyond the base class."""

    def __init__(self, **kwargs):
        """Create the config.

        Args:
            **kwargs: keyword arguments passed unchanged to ``datasets.BuilderConfig``.
        """
        super().__init__(**kwargs)
|
42
|
+
|
43
|
+
|
44
|
+
class Funsd(datasets.GeneratorBasedBuilder):
    """FUNSD dataset builder: per-form word tokens, BIO NER tags, layout boxes and a 3x224x224 page image."""

    BUILDER_CONFIGS = [
        FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
    ]

    def _info(self):
        """Describe the example schema (features), homepage and citation."""
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=datasets.Features(
                {
                    "id": datasets.Value("string"),
                    "tokens": datasets.Sequence(datasets.Value("string")),
                    "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
                    "ner_tags": datasets.Sequence(
                        datasets.features.ClassLabel(
                            names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
                        )
                    ),
                    "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
                    "image_path": datasets.Value("string"),
                }
            ),
            supervised_keys=None,
            homepage="https://guillaumejaume.github.io/FUNSD/",
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Returns SplitGenerators."""
        downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
            ),
        ]

    def get_line_bbox(self, bboxs):
        """Return the union box of ``bboxs``, repeated once per input box.

        Args:
            bboxs: non-empty list of flat coordinate lists (x values at even indices, y values at odd ones).

        Returns:
            ``len(bboxs)`` separate copies of ``[x_min, y_min, x_max, y_max]``, so every word of a
            segment shares the segment-level box.
        """
        x = [coord for box in bboxs for coord in box[0::2]]
        y = [coord for box in bboxs for coord in box[1::2]]

        x0, y0, x1, y1 = min(x), min(y), max(x), max(y)

        assert x1 >= x0 and y1 >= y0
        return [[x0, y0, x1, y1] for _ in bboxs]

    def _generate_examples(self, filepath):
        """Yield ``(guid, example)`` pairs, one per annotation file in ``filepath/annotations``.

        Words of entries labelled ``"other"`` are tagged ``O``. Words of any other entry are tagged ``B-``
        (first word) or ``I-`` (remaining words) followed by the upper-cased label.
        """
        logger.info("⏳ Generating examples from = %s", filepath)
        ann_dir = os.path.join(filepath, "annotations")
        img_dir = os.path.join(filepath, "images")
        for guid, file in enumerate(sorted(os.listdir(ann_dir))):
            tokens = []
            bboxes = []
            ner_tags = []

            file_path = os.path.join(ann_dir, file)
            with open(file_path, "r", encoding="utf8") as f:
                data = json.load(f)
            # Swap only the extension. A plain `str.replace("json", "png")` on the joined path would also
            # rewrite any "json" that occurs in a directory name.
            image_path = os.path.join(img_dir, os.path.splitext(file)[0] + ".png")
            image, size = load_image(image_path)
            for item in data["form"]:
                cur_line_bboxes = []
                words, label = item["words"], item["label"]
                words = [w for w in words if w["text"].strip() != ""]
                if len(words) == 0:
                    continue
                if label == "other":
                    for w in words:
                        tokens.append(w["text"])
                        ner_tags.append("O")
                        cur_line_bboxes.append(normalize_bbox(w["box"], size))
                else:
                    for position, w in enumerate(words):
                        prefix = "B-" if position == 0 else "I-"
                        tokens.append(w["text"])
                        ner_tags.append(prefix + label.upper())
                        cur_line_bboxes.append(normalize_bbox(w["box"], size))
                # Segment-level layout (default, i.e. `--segment_level_layout 1`): every word receives the
                # box of its whole entry. Remove this line to keep word-level boxes, or use the entry's own
                # box via `normalize_bbox(item["box"], size)` instead.
                cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
                bboxes.extend(cur_line_bboxes)
            yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags,
                         "image": image, "image_path": image_path}
|
@@ -0,0 +1,284 @@
|
|
1
|
+
import torchvision.transforms.functional as F
|
2
|
+
import warnings
|
3
|
+
import math
|
4
|
+
import random
|
5
|
+
import numpy as np
|
6
|
+
from PIL import Image
|
7
|
+
import torch
|
8
|
+
|
9
|
+
from detectron2.data.detection_utils import read_image
|
10
|
+
from detectron2.data.transforms import ResizeTransform, TransformList
|
11
|
+
|
12
|
+
def normalize_bbox(bbox, size):
    """Map a pixel-space box onto the 0-1000 coordinate grid.

    Args:
        bbox: sequence whose first four items are ``x0, y0, x1, y1`` in pixels.
        size: ``(width, height)`` of the image the box lives in.

    Returns:
        A list of four ints. Each x value is scaled by ``1000 / width`` and each y value by
        ``1000 / height``, then truncated with ``int``.
    """
    width, height = size[0], size[1]
    x0, y0 = bbox[0], bbox[1]
    left = int(1000 * x0 / width)
    top = int(1000 * y0 / height)
    x1, y1 = bbox[2], bbox[3]
    right = int(1000 * x1 / width)
    bottom = int(1000 * y1 / height)
    return [left, top, right, bottom]
|
19
|
+
|
20
|
+
|
21
|
+
def load_image(image_path):
    """Read an image and return it resized to 224x224, channels first, together with its original size.

    Returns:
        ``(image, (width, height))``. ``image`` is a torch tensor produced by resizing the BGR image read
        by ``read_image`` to 224x224 and permuting it with ``(2, 0, 1)`` (HWC -> CHW). ``(width, height)``
        is the size of the image as read from disk.
    """
    raw = read_image(image_path, format="BGR")
    height, width = raw.shape[0], raw.shape[1]
    to_224 = TransformList([ResizeTransform(h=height, w=width, new_h=224, new_w=224)])
    resized = to_224.apply_image(raw)
    # `.copy()` hands torch its own writeable buffer.
    chw = torch.tensor(resized.copy()).permute(2, 0, 1)
    return chw, (width, height)
|
28
|
+
|
29
|
+
|
30
|
+
def crop(image, i, j, h, w, boxes=None):
    """Crop ``image`` to the ``h`` x ``w`` window whose top-left corner is at row ``i``, column ``j``.

    When ``boxes`` (rows of ``x0, y0, x1, y1``) is given, the boxes are shifted into the crop's frame and
    clipped to ``[0, w] x [0, h]``. Otherwise ``None`` is returned in their place.

    NOTE: clipping is a stop-gap. Boxes that fall outside the crop would better be dropped together with
    their text tokens, which is not implemented here.
    """
    cropped_image = F.crop(image, i, j, h, w)
    if boxes is None:
        return cropped_image, boxes

    upper = torch.as_tensor([w, h], dtype=torch.float32)
    shifted = torch.as_tensor(boxes) - torch.as_tensor([j, i, j, i])
    # View each box as two (x, y) corners so both corners are capped at (w, h) in one call, then floored at 0.
    corners = torch.min(shifted.reshape(-1, 2, 2), upper).clamp(min=0)
    return cropped_image, corners.reshape(-1, 4)
|
44
|
+
|
45
|
+
|
46
|
+
def resize(image, size, interpolation, boxes=None):
    """Resize a PIL ``image`` and optionally scale ``boxes`` by the same per-axis ratios.

    Boxes may not need rescaling in practice: they are later normalized to a 1000x1000 grid, which is
    already compatible with a square 224x224 image.

    Args:
        image: PIL image (its ``.size`` attribute is read as ``(width, height)``).
        size: target size forwarded to ``F.resize``.
        interpolation: interpolation mode forwarded to ``F.resize``.
        boxes: optional rows of ``x0, y0, x1, y1``. A tensor, numpy array or nested list is accepted.

    Returns:
        ``(rescaled_image, scaled_boxes)``, where ``scaled_boxes`` is ``None`` when ``boxes`` is ``None``.
    """
    rescaled_image = F.resize(image, size, interpolation)

    if boxes is None:
        return rescaled_image, None

    ratio_width, ratio_height = tuple(
        float(new) / float(old) for new, old in zip(rescaled_image.size, image.size)
    )

    # `torch.as_tensor` (as in `crop`) lets plain lists through; `list * tensor` used to raise TypeError.
    # Tensors are passed through without a copy, so tensor inputs behave exactly as before.
    scaled_boxes = torch.as_tensor(boxes) * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])

    return rescaled_image, scaled_boxes
|
61
|
+
|
62
|
+
|
63
|
+
def clamp(num, min_value, max_value):
    """Limit ``num`` to ``[min_value, max_value]``.

    The upper bound is applied first, so when ``min_value > max_value`` the result is ``min_value``.
    """
    upper_bounded = min(num, max_value)
    return max(upper_bounded, min_value)
|
65
|
+
|
66
|
+
|
67
|
+
def get_bb(bb, page_size):
    """Turn a flat point list ``[x0, y0, x1, y1, ...]`` into a 0-1000 normalized bounding box.

    Even-indexed values are x coordinates and odd-indexed values are y coordinates. The extent of the
    points is clipped to ``[0, width - 1] x [0, height - 1]`` and then scaled by ``1000 / width``
    (x values) or ``1000 / height`` (y values).

    Args:
        bb: flat sequence of coordinates; each item must be accepted by ``float``.
        page_size: ``(width, height)`` of the page.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as ints.
    """
    coords = [float(value) for value in bb]
    xs = coords[0::2]
    ys = coords[1::2]
    (width, height) = page_size

    def _limit(value, upper):
        # Cap at `upper` first, then floor at 0.
        return max(min(value, upper), 0)

    left = _limit(min(xs), width - 1)
    top = _limit(min(ys), height - 1)
    right = _limit(max(xs), width - 1)
    bottom = _limit(max(ys), height - 1)
    return [
        int(1000 * left / width),
        int(1000 * top / height),
        int(1000 * right / width),
        int(1000 * bottom / height),
    ]
|
89
|
+
|
90
|
+
|
91
|
+
class ToNumpy:
    """Convert a PIL image (or array-like) into a channels-first ``uint8`` numpy array.

    A 2-D (single-channel) input gains a leading channel axis of size 1.
    """

    def __call__(self, pil_img):
        array = np.array(pil_img, dtype=np.uint8)
        if array.ndim < 3:
            # Single channel: add a trailing channel axis so the move below yields (1, H, W).
            array = array[..., np.newaxis]
        # HWC -> CHW
        return np.moveaxis(array, 2, 0)
|
99
|
+
|
100
|
+
|
101
|
+
class ToTensor:
    """Convert a PIL image (or array-like) into a channels-first torch tensor of a chosen dtype."""

    def __init__(self, dtype=torch.float32):
        # dtype of the returned tensor; the intermediate numpy array is always uint8.
        self.dtype = dtype

    def __call__(self, pil_img):
        pixels = np.array(pil_img, dtype=np.uint8)
        if pixels.ndim < 3:
            pixels = pixels[..., np.newaxis]
        chw = np.moveaxis(pixels, 2, 0)  # HWC -> CHW
        return torch.from_numpy(chw).to(dtype=self.dtype)
|
112
|
+
|
113
|
+
|
114
|
+
# Display labels for interpolation modes, read by `__repr__` in this module. Each label is the
# attribute path of its `F.InterpolationMode` member, e.g. 'F.InterpolationMode.NEAREST'.
_pil_interpolation_to_str = {
    mode: 'F.InterpolationMode.' + mode.name
    for mode in (
        F.InterpolationMode.NEAREST,
        F.InterpolationMode.BILINEAR,
        F.InterpolationMode.BICUBIC,
        F.InterpolationMode.LANCZOS,
        F.InterpolationMode.HAMMING,
        F.InterpolationMode.BOX,
    )
}
|
122
|
+
|
123
|
+
|
124
|
+
def _pil_interp(method):
    """Translate an interpolation name into an ``F.InterpolationMode``.

    ``'bicubic'``, ``'lanczos'`` and ``'hamming'`` map to their modes. Any other value (including
    ``'bilinear'`` and ``'nearest'``) falls back to bilinear.
    """
    # Open question carried over: should 'nearest' be honoured instead of falling back to bilinear?
    for name, mode in (
        ('bicubic', F.InterpolationMode.BICUBIC),
        ('lanczos', F.InterpolationMode.LANCZOS),
        ('hamming', F.InterpolationMode.HAMMING),
    ):
        if method == name:
            return mode
    return F.InterpolationMode.BILINEAR
|
134
|
+
|
135
|
+
|
136
|
+
class Compose:
    """Chain several transforms, feeding each one's output image to the next.

    Unlike ``torchvision.transforms.Compose``, every transform is called as
    ``transform(img, augmentation, box)``, so each must accept those three positional arguments.
    The same ``augmentation`` and ``box`` values are passed to every stage. Not TorchScript-compatible.

    Args:
        transforms (list): callables with the signature ``(img, augmentation, box)``.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, augmentation=False, box=None):
        result = img
        for transform in self.transforms:
            result = transform(result, augmentation, box)
        return result
|
171
|
+
|
172
|
+
|
173
|
+
class RandomResizedCropAndInterpolationWithTwoPic:
    """Optionally random-crop a PIL image, then resize it to one or two target sizes.

    With ``augmentation=True`` a crop is taken first. Its area is a random fraction of the image area
    (default 8%-100%) and its aspect ratio is random (default 3/4-4/3). The (possibly cropped) image is
    then resized to ``size``. If ``second_size`` is set, a second image is made by resizing that first
    result to ``second_size``.

    Args:
        size: output size. An int means a square output; a tuple or list is used as given.
        second_size: optional size of the second output, same format as ``size``.
        scale: ``(min, max)`` range of the crop area as a fraction of the image area.
        ratio: ``(min, max)`` range of the crop aspect ratio (width / height).
        interpolation: interpolation name for the first resize. Default: ``'bilinear'``.
        second_interpolation: interpolation name for the second resize. Default: ``'lanczos'``.
    """

    def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation='bilinear', second_interpolation='lanczos'):
        # Lists are accepted like tuples; previously a list was wrapped as `(list, list)`, which is not a
        # valid resize size.
        self.size = self._as_size(size)
        self.second_size = self._as_size(second_size) if second_size is not None else None
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("range should be of kind (min, max)")

        self.interpolation = _pil_interp(interpolation)
        self.second_interpolation = _pil_interp(second_interpolation)
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def _as_size(size):
        """Return ``size`` as a tuple: a tuple/list is converted as-is, any other value ``s`` becomes ``(s, s)``."""
        if isinstance(size, (tuple, list)):
            return tuple(size)
        return (size, size)

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Up to 10 random (area, aspect ratio) draws are tried. If none fits inside the image, a central
        crop with the aspect ratio clamped to ``ratio`` is used instead.

        Args:
            img (PIL Image): image to be cropped; ``img.size`` is read as ``(width, height)``.
            scale (tuple): range of the crop area as a fraction of the image area.
            ratio (tuple): range of the crop aspect ratio.

        Returns:
            tuple: params ``(i, j, h, w)`` to be passed to ``crop`` for a random sized crop.
        """
        width, height = img.size[0], img.size[1]
        area = width * height
        # Sample the aspect ratio log-uniformly; the bounds do not change between attempts.
        log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

        for _ in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if w <= width and h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w

        # Fallback to central crop
        in_ratio = width / height
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def __call__(self, img, augmentation=False, box=None):
        """Optionally random-crop ``img``, then resize it.

        Args:
            img (PIL Image): image to transform.
            augmentation (bool): when True, crop with ``get_params`` before resizing.
            box: accepted so the object can be used inside ``Compose``; currently unused.

        Returns:
            tuple: ``(img, second_img)``. ``img`` is resized to ``self.size``; ``second_img`` is that result
            resized to ``self.second_size``, or ``None`` when no second size was configured.
        """
        if augmentation:
            i, j, h, w = self.get_params(img, self.scale, self.ratio)
            img = F.crop(img, i, j, h, w)
        img = F.resize(img, self.size, self.interpolation)
        second_img = (F.resize(img, self.second_size, self.second_interpolation)
                      if self.second_size is not None else None)
        return img, second_img

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
        else:
            interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0}'.format(interpolate_str)
        if self.second_size is not None:
            format_string += ', second_size={0}'.format(self.second_size)
            format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])
        format_string += ')'
        return format_string
|
278
|
+
|
279
|
+
|
280
|
+
def pil_loader(path: str) -> Image.Image:
    """Load the image stored at ``path`` and return an RGB copy of it.

    The file is opened explicitly and converted while the handle is still open (``convert`` returns a
    converted copy). This avoids the ResourceWarning described in
    https://github.com/python-pillow/Pillow/issues/835.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')
|