magic-pdf 0.5.13__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. magic_pdf/cli/magicpdf.py +18 -7
  2. magic_pdf/libs/config_reader.py +10 -0
  3. magic_pdf/libs/version.py +1 -1
  4. magic_pdf/model/__init__.py +1 -0
  5. magic_pdf/model/doc_analyze_by_custom_model.py +38 -15
  6. magic_pdf/model/model_list.py +1 -0
  7. magic_pdf/model/pdf_extract_kit.py +196 -0
  8. magic_pdf/model/pek_sub_modules/__init__.py +0 -0
  9. magic_pdf/model/pek_sub_modules/layoutlmv3/__init__.py +0 -0
  10. magic_pdf/model/pek_sub_modules/layoutlmv3/backbone.py +179 -0
  11. magic_pdf/model/pek_sub_modules/layoutlmv3/beit.py +671 -0
  12. magic_pdf/model/pek_sub_modules/layoutlmv3/deit.py +476 -0
  13. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/__init__.py +7 -0
  14. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/__init__.py +2 -0
  15. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/cord.py +171 -0
  16. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py +124 -0
  17. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py +136 -0
  18. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py +284 -0
  19. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/xfund.py +213 -0
  20. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/__init__.py +7 -0
  21. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/__init__.py +24 -0
  22. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/configuration_layoutlmv3.py +60 -0
  23. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/modeling_layoutlmv3.py +1282 -0
  24. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3.py +32 -0
  25. magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/models/layoutlmv3/tokenization_layoutlmv3_fast.py +34 -0
  26. magic_pdf/model/pek_sub_modules/layoutlmv3/model_init.py +150 -0
  27. magic_pdf/model/pek_sub_modules/layoutlmv3/rcnn_vl.py +163 -0
  28. magic_pdf/model/pek_sub_modules/layoutlmv3/visualizer.py +1236 -0
  29. magic_pdf/model/pek_sub_modules/post_process.py +36 -0
  30. magic_pdf/model/pek_sub_modules/self_modify.py +260 -0
  31. magic_pdf/model/pp_structure_v2.py +7 -0
  32. magic_pdf/pipe/AbsPipe.py +8 -14
  33. magic_pdf/pipe/OCRPipe.py +12 -8
  34. magic_pdf/pipe/TXTPipe.py +12 -8
  35. magic_pdf/pipe/UNIPipe.py +9 -7
  36. magic_pdf/resources/model_config/UniMERNet/demo.yaml +46 -0
  37. magic_pdf/resources/model_config/layoutlmv3/layoutlmv3_base_inference.yaml +351 -0
  38. magic_pdf/resources/model_config/model_configs.yaml +9 -0
  39. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/METADATA +18 -8
  40. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/RECORD +44 -18
  41. magic_pdf/model/360_layout_analysis.py +0 -8
  42. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/LICENSE.md +0 -0
  43. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/WHEEL +0 -0
  44. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/entry_points.txt +0 -0
  45. {magic_pdf-0.5.13.dist-info → magic_pdf-0.6.0.dist-info}/top_level.txt +0 -0
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/data_collator.py (new file)
@@ -0,0 +1,124 @@
+ import torch
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ from transformers import BatchEncoding, PreTrainedTokenizerBase
+ from transformers.data.data_collator import (
+     DataCollatorMixin,
+     _torch_collate_batch,
+ )
+ from transformers.file_utils import PaddingStrategy
+
+ from typing import NewType
+ InputDataClass = NewType("InputDataClass", Any)
+
+ def pre_calc_rel_mat(segment_ids):
+     valid_span = torch.zeros((segment_ids.shape[0], segment_ids.shape[1], segment_ids.shape[1]),
+                              device=segment_ids.device, dtype=torch.bool)
+     for i in range(segment_ids.shape[0]):
+         for j in range(segment_ids.shape[1]):
+             valid_span[i, j, :] = segment_ids[i, :] == segment_ids[i, j]
+
+     return valid_span
+
+ @dataclass
+ class DataCollatorForKeyValueExtraction(DataCollatorMixin):
+     """
+     Data collator that will dynamically pad the inputs received, as well as the labels.
+     Args:
+         tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
+             The tokenizer used for encoding the data.
+         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+             among:
+             * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+               sequence if provided).
+             * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+               maximum acceptable input length for the model if that argument is not provided.
+             * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+               different lengths).
+         max_length (:obj:`int`, `optional`):
+             Maximum length of the returned list and optionally padding length (see above).
+         pad_to_multiple_of (:obj:`int`, `optional`):
+             If set will pad the sequence to a multiple of the provided value.
+             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+             7.5 (Volta).
+         label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
+             The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
+     """
+
+     tokenizer: PreTrainedTokenizerBase
+     padding: Union[bool, str, PaddingStrategy] = True
+     max_length: Optional[int] = None
+     pad_to_multiple_of: Optional[int] = None
+     label_pad_token_id: int = -100
+
+     def __call__(self, features):
+         label_name = "label" if "label" in features[0].keys() else "labels"
+         labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
+
+         images = None
+         if "images" in features[0]:
+             images = torch.stack([torch.tensor(d.pop("images")) for d in features])
+             IMAGE_LEN = int(images.shape[-1] / 16) * int(images.shape[-1] / 16) + 1
+
+         batch = self.tokenizer.pad(
+             features,
+             padding=self.padding,
+             max_length=self.max_length,
+             pad_to_multiple_of=self.pad_to_multiple_of,
+             # Conversion to tensors will fail if we have labels as they are not of the same length yet.
+             return_tensors="pt" if labels is None else None,
+         )
+
+         if images is not None:
+             batch["images"] = images
+             batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) and k == 'attention_mask' else v
+                      for k, v in batch.items()}
+             visual_attention_mask = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long)
+             batch["attention_mask"] = torch.cat([batch['attention_mask'], visual_attention_mask], dim=1)
+
+         if labels is None:
+             return batch
+
+         has_bbox_input = "bbox" in features[0]
+         has_position_input = "position_ids" in features[0]
+         padding_idx = self.tokenizer.pad_token_id
+         sequence_length = torch.tensor(batch["input_ids"]).shape[1]
+         padding_side = self.tokenizer.padding_side
+         if padding_side == "right":
+             batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
+             if has_bbox_input:
+                 batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
+             if has_position_input:
+                 batch["position_ids"] = [position_id + [padding_idx] * (sequence_length - len(position_id))
+                                          for position_id in batch["position_ids"]]
+
+         else:
+             batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
+             if has_bbox_input:
+                 batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]
+             if has_position_input:
+                 batch["position_ids"] = [[padding_idx] * (sequence_length - len(position_id))
+                                          + position_id for position_id in batch["position_ids"]]
+
+         if 'segment_ids' in batch:
+             assert 'position_ids' in batch
+             for i in range(len(batch['segment_ids'])):
+                 batch['segment_ids'][i] = batch['segment_ids'][i] + [batch['segment_ids'][i][-1] + 1] * (sequence_length - len(batch['segment_ids'][i])) + [
+                     batch['segment_ids'][i][-1] + 2] * IMAGE_LEN
+
+         batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}
+
+         if 'segment_ids' in batch:
+             valid_span = pre_calc_rel_mat(
+                 segment_ids=batch['segment_ids']
+             )
+             batch['valid_span'] = valid_span
+             del batch['segment_ids']
+
+         if images is not None:
+             visual_labels = torch.ones((len(batch['input_ids']), IMAGE_LEN), dtype=torch.long) * -100
+             batch["labels"] = torch.cat([batch['labels'], visual_labels], dim=1)
+
+         return batch
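
For context, a minimal usage sketch of the collator above. Only the import path comes from this wheel; the `bert-base-uncased` checkpoint and the toy feature dicts are illustrative assumptions, not part of the package.

```python
# Sketch: pad two token-classification features of different lengths into one batch.
# Assumes magic-pdf 0.6.0 is installed together with transformers and torch.
from transformers import AutoTokenizer

from magic_pdf.model.pek_sub_modules.layoutlmv3.layoutlmft.data.data_collator import (
    DataCollatorForKeyValueExtraction,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint
collator = DataCollatorForKeyValueExtraction(tokenizer=tokenizer, padding="longest")

features = [  # toy features; real ones come from a LayoutLMv3 preprocessing step
    {"input_ids": [101, 7592, 102], "labels": [0, 1, 0], "bbox": [[10, 10, 80, 30]] * 3},
    {"input_ids": [101, 2088, 999, 102], "labels": [0, 2, 2, 0], "bbox": [[10, 40, 90, 60]] * 4},
]

batch = collator(features)
# The shorter sequence is right-padded: its labels get label_pad_token_id (-100) so the
# loss ignores the padding, and its "bbox" list is padded with [0, 0, 0, 0] boxes.
print(batch["input_ids"].shape, batch["labels"].shape, batch["bbox"].shape)
# -> torch.Size([2, 4]) torch.Size([2, 4]) torch.Size([2, 4, 4])
```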
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/funsd.py (new file)
@@ -0,0 +1,136 @@
+ # coding=utf-8
+ '''
+ Reference: https://huggingface.co/datasets/nielsr/funsd/blob/main/funsd.py
+ '''
+ import json
+ import os
+
+ import datasets
+
+ from .image_utils import load_image, normalize_bbox
+
+
+ logger = datasets.logging.get_logger(__name__)
+
+
+ _CITATION = """\
+ @article{Jaume2019FUNSDAD,
+   title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
+   author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
+   journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
+   year={2019},
+   volume={2},
+   pages={1-6}
+ }
+ """
+
+ _DESCRIPTION = """\
+ https://guillaumejaume.github.io/FUNSD/
+ """
+
+
+ class FunsdConfig(datasets.BuilderConfig):
+     """BuilderConfig for FUNSD"""
+
+     def __init__(self, **kwargs):
+         """BuilderConfig for FUNSD.
+
+         Args:
+             **kwargs: keyword arguments forwarded to super.
+         """
+         super(FunsdConfig, self).__init__(**kwargs)
+
+
+ class Funsd(datasets.GeneratorBasedBuilder):
+     """Conll2003 dataset."""
+
+     BUILDER_CONFIGS = [
+         FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
+     ]
+
+     def _info(self):
+         return datasets.DatasetInfo(
+             description=_DESCRIPTION,
+             features=datasets.Features(
+                 {
+                     "id": datasets.Value("string"),
+                     "tokens": datasets.Sequence(datasets.Value("string")),
+                     "bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
+                     "ner_tags": datasets.Sequence(
+                         datasets.features.ClassLabel(
+                             names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
+                         )
+                     ),
+                     "image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
+                     "image_path": datasets.Value("string"),
+                 }
+             ),
+             supervised_keys=None,
+             homepage="https://guillaumejaume.github.io/FUNSD/",
+             citation=_CITATION,
+         )
+
+     def _split_generators(self, dl_manager):
+         """Returns SplitGenerators."""
+         downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
+         return [
+             datasets.SplitGenerator(
+                 name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
+             ),
+             datasets.SplitGenerator(
+                 name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
+             ),
+         ]
+
+     def get_line_bbox(self, bboxs):
+         x = [bboxs[i][j] for i in range(len(bboxs)) for j in range(0, len(bboxs[i]), 2)]
+         y = [bboxs[i][j] for i in range(len(bboxs)) for j in range(1, len(bboxs[i]), 2)]
+
+         x0, y0, x1, y1 = min(x), min(y), max(x), max(y)
+
+         assert x1 >= x0 and y1 >= y0
+         bbox = [[x0, y0, x1, y1] for _ in range(len(bboxs))]
+         return bbox
+
+     def _generate_examples(self, filepath):
+         logger.info("⏳ Generating examples from = %s", filepath)
+         ann_dir = os.path.join(filepath, "annotations")
+         img_dir = os.path.join(filepath, "images")
+         for guid, file in enumerate(sorted(os.listdir(ann_dir))):
+             tokens = []
+             bboxes = []
+             ner_tags = []
+
+             file_path = os.path.join(ann_dir, file)
+             with open(file_path, "r", encoding="utf8") as f:
+                 data = json.load(f)
+             image_path = os.path.join(img_dir, file)
+             image_path = image_path.replace("json", "png")
+             image, size = load_image(image_path)
+             for item in data["form"]:
+                 cur_line_bboxes = []
+                 words, label = item["words"], item["label"]
+                 words = [w for w in words if w["text"].strip() != ""]
+                 if len(words) == 0:
+                     continue
+                 if label == "other":
+                     for w in words:
+                         tokens.append(w["text"])
+                         ner_tags.append("O")
+                         cur_line_bboxes.append(normalize_bbox(w["box"], size))
+                 else:
+                     tokens.append(words[0]["text"])
+                     ner_tags.append("B-" + label.upper())
+                     cur_line_bboxes.append(normalize_bbox(words[0]["box"], size))
+                     for w in words[1:]:
+                         tokens.append(w["text"])
+                         ner_tags.append("I-" + label.upper())
+                         cur_line_bboxes.append(normalize_bbox(w["box"], size))
+                 # by default: --segment_level_layout 1
+                 # if do not want to use segment_level_layout, comment the following line
+                 cur_line_bboxes = self.get_line_bbox(cur_line_bboxes)
+                 # box = normalize_bbox(item["box"], size)
+                 # cur_line_bboxes = [box for _ in range(len(words))]
+                 bboxes.extend(cur_line_bboxes)
+             yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags,
+                          "image": image, "image_path": image_path}
magic_pdf/model/pek_sub_modules/layoutlmv3/layoutlmft/data/image_utils.py (new file)
@@ -0,0 +1,284 @@
+ import torchvision.transforms.functional as F
+ import warnings
+ import math
+ import random
+ import numpy as np
+ from PIL import Image
+ import torch
+
+ from detectron2.data.detection_utils import read_image
+ from detectron2.data.transforms import ResizeTransform, TransformList
+
+ def normalize_bbox(bbox, size):
+     return [
+         int(1000 * bbox[0] / size[0]),
+         int(1000 * bbox[1] / size[1]),
+         int(1000 * bbox[2] / size[0]),
+         int(1000 * bbox[3] / size[1]),
+     ]
+
+
+ def load_image(image_path):
+     image = read_image(image_path, format="BGR")
+     h = image.shape[0]
+     w = image.shape[1]
+     img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)])
+     image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1)  # copy to make it writeable
+     return image, (w, h)
+
+
+ def crop(image, i, j, h, w, boxes=None):
+     cropped_image = F.crop(image, i, j, h, w)
+
+     if boxes is not None:
+         # Currently we cannot use this case since when some boxes is out of the cropped image,
+         # it may be better to drop out these boxes along with their text input (instead of min or clamp)
+         # which haven't been implemented here
+         max_size = torch.as_tensor([w, h], dtype=torch.float32)
+         cropped_boxes = torch.as_tensor(boxes) - torch.as_tensor([j, i, j, i])
+         cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
+         cropped_boxes = cropped_boxes.clamp(min=0)
+         boxes = cropped_boxes.reshape(-1, 4)
+
+     return cropped_image, boxes
+
+
+ def resize(image, size, interpolation, boxes=None):
+     # It seems that we do not need to resize boxes here, since the boxes will be resized to 1000x1000 finally,
+     # which is compatible with a square image size of 224x224
+     rescaled_image = F.resize(image, size, interpolation)
+
+     if boxes is None:
+         return rescaled_image, None
+
+     ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
+     ratio_width, ratio_height = ratios
+
+     # boxes = boxes.copy()
+     scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
+
+     return rescaled_image, scaled_boxes
+
+
+ def clamp(num, min_value, max_value):
+     return max(min(num, max_value), min_value)
+
+
+ def get_bb(bb, page_size):
+     bbs = [float(j) for j in bb]
+     xs, ys = [], []
+     for i, b in enumerate(bbs):
+         if i % 2 == 0:
+             xs.append(b)
+         else:
+             ys.append(b)
+     (width, height) = page_size
+     return_bb = [
+         clamp(min(xs), 0, width - 1),
+         clamp(min(ys), 0, height - 1),
+         clamp(max(xs), 0, width - 1),
+         clamp(max(ys), 0, height - 1),
+     ]
+     return_bb = [
+         int(1000 * return_bb[0] / width),
+         int(1000 * return_bb[1] / height),
+         int(1000 * return_bb[2] / width),
+         int(1000 * return_bb[3] / height),
+     ]
+     return return_bb
+
+
+ class ToNumpy:
+
+     def __call__(self, pil_img):
+         np_img = np.array(pil_img, dtype=np.uint8)
+         if np_img.ndim < 3:
+             np_img = np.expand_dims(np_img, axis=-1)
+         np_img = np.rollaxis(np_img, 2)  # HWC to CHW
+         return np_img
+
+
+ class ToTensor:
+
+     def __init__(self, dtype=torch.float32):
+         self.dtype = dtype
+
+     def __call__(self, pil_img):
+         np_img = np.array(pil_img, dtype=np.uint8)
+         if np_img.ndim < 3:
+             np_img = np.expand_dims(np_img, axis=-1)
+         np_img = np.rollaxis(np_img, 2)  # HWC to CHW
+         return torch.from_numpy(np_img).to(dtype=self.dtype)
+
+
+ _pil_interpolation_to_str = {
+     F.InterpolationMode.NEAREST: 'F.InterpolationMode.NEAREST',
+     F.InterpolationMode.BILINEAR: 'F.InterpolationMode.BILINEAR',
+     F.InterpolationMode.BICUBIC: 'F.InterpolationMode.BICUBIC',
+     F.InterpolationMode.LANCZOS: 'F.InterpolationMode.LANCZOS',
+     F.InterpolationMode.HAMMING: 'F.InterpolationMode.HAMMING',
+     F.InterpolationMode.BOX: 'F.InterpolationMode.BOX',
+ }
+
+
+ def _pil_interp(method):
+     if method == 'bicubic':
+         return F.InterpolationMode.BICUBIC
+     elif method == 'lanczos':
+         return F.InterpolationMode.LANCZOS
+     elif method == 'hamming':
+         return F.InterpolationMode.HAMMING
+     else:
+         # default bilinear, do we want to allow nearest?
+         return F.InterpolationMode.BILINEAR
+
+
+ class Compose:
+     """Composes several transforms together. This transform does not support torchscript.
+     Please, see the note below.
+
+     Args:
+         transforms (list of ``Transform`` objects): list of transforms to compose.
+
+     Example:
+         >>> transforms.Compose([
+         >>>     transforms.CenterCrop(10),
+         >>>     transforms.PILToTensor(),
+         >>>     transforms.ConvertImageDtype(torch.float),
+         >>> ])
+
+     .. note::
+         In order to script the transformations, please use ``torch.nn.Sequential`` as below.
+
+         >>> transforms = torch.nn.Sequential(
+         >>>     transforms.CenterCrop(10),
+         >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+         >>> )
+         >>> scripted_transforms = torch.jit.script(transforms)
+
+         Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
+         `lambda` functions or ``PIL.Image``.
+
+     """
+
+     def __init__(self, transforms):
+         self.transforms = transforms
+
+     def __call__(self, img, augmentation=False, box=None):
+         for t in self.transforms:
+             img = t(img, augmentation, box)
+         return img
+
+
+ class RandomResizedCropAndInterpolationWithTwoPic:
+     """Crop the given PIL Image to random size and aspect ratio with random interpolation.
+     A crop of random size (default: of 0.08 to 1.0) of the original size and a random
+     aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
+     is finally resized to given size.
+     This is popularly used to train the Inception networks.
+     Args:
+         size: expected output size of each edge
+         scale: range of size of the origin size cropped
+         ratio: range of aspect ratio of the origin aspect ratio cropped
+         interpolation: Default: PIL.Image.BILINEAR
+     """
+
+     def __init__(self, size, second_size=None, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
+                  interpolation='bilinear', second_interpolation='lanczos'):
+         if isinstance(size, tuple):
+             self.size = size
+         else:
+             self.size = (size, size)
+         if second_size is not None:
+             if isinstance(second_size, tuple):
+                 self.second_size = second_size
+             else:
+                 self.second_size = (second_size, second_size)
+         else:
+             self.second_size = None
+         if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
+             warnings.warn("range should be of kind (min, max)")
+
+         self.interpolation = _pil_interp(interpolation)
+         self.second_interpolation = _pil_interp(second_interpolation)
+         self.scale = scale
+         self.ratio = ratio
+
+     @staticmethod
+     def get_params(img, scale, ratio):
+         """Get parameters for ``crop`` for a random sized crop.
+         Args:
+             img (PIL Image): Image to be cropped.
+             scale (tuple): range of size of the origin size cropped
+             ratio (tuple): range of aspect ratio of the origin aspect ratio cropped
+         Returns:
+             tuple: params (i, j, h, w) to be passed to ``crop`` for a random
+                 sized crop.
+         """
+         area = img.size[0] * img.size[1]
+
+         for attempt in range(10):
+             target_area = random.uniform(*scale) * area
+             log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
+             aspect_ratio = math.exp(random.uniform(*log_ratio))
+
+             w = int(round(math.sqrt(target_area * aspect_ratio)))
+             h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+             if w <= img.size[0] and h <= img.size[1]:
+                 i = random.randint(0, img.size[1] - h)
+                 j = random.randint(0, img.size[0] - w)
+                 return i, j, h, w
+
+         # Fallback to central crop
+         in_ratio = img.size[0] / img.size[1]
+         if in_ratio < min(ratio):
+             w = img.size[0]
+             h = int(round(w / min(ratio)))
+         elif in_ratio > max(ratio):
+             h = img.size[1]
+             w = int(round(h * max(ratio)))
+         else:  # whole image
+             w = img.size[0]
+             h = img.size[1]
+         i = (img.size[1] - h) // 2
+         j = (img.size[0] - w) // 2
+         return i, j, h, w
+
+     def __call__(self, img, augmentation=False, box=None):
+         """
+         Args:
+             img (PIL Image): Image to be cropped and resized.
+         Returns:
+             PIL Image: Randomly cropped and resized image.
+         """
+         if augmentation:
+             i, j, h, w = self.get_params(img, self.scale, self.ratio)
+             img = F.crop(img, i, j, h, w)
+             # img, box = crop(img, i, j, h, w, box)
+         img = F.resize(img, self.size, self.interpolation)
+         second_img = F.resize(img, self.second_size, self.second_interpolation) \
+             if self.second_size is not None else None
+         return img, second_img
+
+     def __repr__(self):
+         if isinstance(self.interpolation, (tuple, list)):
+             interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
+         else:
+             interpolate_str = _pil_interpolation_to_str[self.interpolation]
+         format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+         format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
+         format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
+         format_string += ', interpolation={0}'.format(interpolate_str)
+         if self.second_size is not None:
+             format_string += ', second_size={0}'.format(self.second_size)
+             format_string += ', second_interpolation={0}'.format(_pil_interpolation_to_str[self.second_interpolation])
+         format_string += ')'
+         return format_string
+
+
+ def pil_loader(path: str) -> Image.Image:
+     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+     with open(path, 'rb') as f:
+         img = Image.open(f)
+         return img.convert('RGB')
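
Finally, a small sketch exercising two of the helpers above, `normalize_bbox` and `RandomResizedCropAndInterpolationWithTwoPic`. The page size, box, and blank test image are made-up values, detectron2 and torchvision from the package's dependency set are assumed installed, and augmentation is left off so only the deterministic resize path runs.

```python
# Sketch: map a pixel-space box onto the 0-1000 LayoutLM grid, then produce the
# resized main view and the optional smaller second view. Values are made up.
from PIL import Image

from magic_pdf.model.pek_sub_modules.layoutlmv3.layoutlmft.data.image_utils import (
    RandomResizedCropAndInterpolationWithTwoPic,
    normalize_bbox,
)

page_w, page_h = 612, 792                       # letter-size page in points (assumed)
word_box = [72, 100, 200, 130]                  # x0, y0, x1, y1 in page coordinates
print(normalize_bbox(word_box, (page_w, page_h)))   # -> [117, 126, 326, 164]

transform = RandomResizedCropAndInterpolationWithTwoPic(
    size=224, second_size=112,                  # main view plus a smaller second view
    interpolation='bicubic', second_interpolation='lanczos',
)
img = Image.new("RGB", (page_w, page_h), "white")
main_view, second_view = transform(img, augmentation=False)
print(main_view.size, second_view.size)         # -> (224, 224) (112, 112)
```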