custom-layoutparser 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- custom_layoutparser-0.1.0.dist-info/METADATA +5 -0
- custom_layoutparser-0.1.0.dist-info/RECORD +36 -0
- custom_layoutparser-0.1.0.dist-info/WHEEL +5 -0
- custom_layoutparser-0.1.0.dist-info/top_level.txt +1 -0
- layoutparser/__init__.py +89 -0
- layoutparser/elements/__init__.py +25 -0
- layoutparser/elements/base.py +275 -0
- layoutparser/elements/errors.py +26 -0
- layoutparser/elements/layout.py +348 -0
- layoutparser/elements/layout_elements.py +1352 -0
- layoutparser/elements/utils.py +82 -0
- layoutparser/file_utils.py +235 -0
- layoutparser/io/__init__.py +2 -0
- layoutparser/io/basic.py +148 -0
- layoutparser/io/pdf.py +225 -0
- layoutparser/models/__init__.py +18 -0
- layoutparser/models/auto_layoutmodel.py +70 -0
- layoutparser/models/base_catalog.py +34 -0
- layoutparser/models/base_layoutmodel.py +88 -0
- layoutparser/models/detectron2/__init__.py +18 -0
- layoutparser/models/detectron2/catalog.py +142 -0
- layoutparser/models/detectron2/layoutmodel.py +168 -0
- layoutparser/models/effdet/__init__.py +16 -0
- layoutparser/models/effdet/catalog.py +88 -0
- layoutparser/models/effdet/layoutmodel.py +256 -0
- layoutparser/models/model_config.py +133 -0
- layoutparser/models/paddledetection/__init__.py +17 -0
- layoutparser/models/paddledetection/catalog.py +214 -0
- layoutparser/models/paddledetection/layoutmodel.py +297 -0
- layoutparser/ocr/__init__.py +16 -0
- layoutparser/ocr/base.py +41 -0
- layoutparser/ocr/gcv_agent.py +288 -0
- layoutparser/ocr/tesseract_agent.py +193 -0
- layoutparser/tools/__init__.py +5 -0
- layoutparser/tools/shape_operations.py +167 -0
- layoutparser/visualization.py +571 -0
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
# Copyright 2021 The Layout Parser team. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import List, Optional, Union, Dict, Any, Tuple
|
|
16
|
+
|
|
17
|
+
from PIL import Image
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from .catalog import PathManager, LABEL_MAP_CATALOG, MODEL_CATALOG
|
|
21
|
+
from ..base_layoutmodel import BaseLayoutModel
|
|
22
|
+
from ...elements import Rectangle, TextBlock, Layout
|
|
23
|
+
|
|
24
|
+
from ...file_utils import is_effdet_available, is_torch_cuda_available
|
|
25
|
+
|
|
26
|
+
if is_effdet_available():
|
|
27
|
+
import torch
|
|
28
|
+
from effdet import create_model
|
|
29
|
+
from effdet.data.transforms import (
|
|
30
|
+
IMAGENET_DEFAULT_MEAN,
|
|
31
|
+
IMAGENET_DEFAULT_STD,
|
|
32
|
+
transforms_coco_eval,
|
|
33
|
+
)
|
|
34
|
+
else:
|
|
35
|
+
# Copied from https://github.com/rwightman/efficientdet-pytorch/blob/c5b694aa34900fdee6653210d856ca8320bf7d4e/effdet/data/transforms.py#L13
|
|
36
|
+
# Such that when effdet is not loaded, we'll still have default values for IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
37
|
+
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
|
38
|
+
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
|
39
|
+
# IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
|
40
|
+
# IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class InputTransform:
    """Turn a PIL image into the batched, normalized tensor EfficientDet consumes.

    Wraps effdet's ``transforms_coco_eval`` for resizing/padding and then
    normalizes with ``mean``/``std`` (given on a 0-1 scale, rescaled to 0-255
    here before being applied).

    Args:
        image_size: Target size forwarded to ``transforms_coco_eval``.
        mean: Per-channel mean (3 values). Defaults to ImageNet's.
        std: Per-channel standard deviation (3 values). Defaults to ImageNet's.
    """

    def __init__(
        self,
        image_size,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
    ):

        self.mean = mean
        self.std = std

        # NOTE(review): with use_prefetcher=True this transform presumably
        # returns un-normalized 0-255 pixel data (hence the manual, *255-scaled
        # normalization in `preprocess`) — confirm against effdet.
        self.transform = transforms_coco_eval(
            image_size,
            interpolation="bilinear",
            use_prefetcher=True,
            fill_color="mean",
            mean=self.mean,
            std=self.std,
        )

        # Shaped (1, 3, 1, 1) so they broadcast over a batch whose channel axis is dim 1.
        self.mean_tensor = torch.tensor([x * 255 for x in mean]).view(1, 3, 1, 1)
        self.std_tensor = torch.tensor([x * 255 for x in std]).view(1, 3, 1, 1)

    def preprocess(self, image: "Image.Image") -> Tuple["torch.Tensor", Dict]:
        """Prepare one image for the model.

        Args:
            image: A PIL image; it is converted to RGB first.

        Returns:
            A ``(pixels, image_info)`` pair. ``pixels`` is a float tensor with a
            leading batch dimension of 1, normalized with ``mean``/``std``.
            ``image_info`` maps each metadata key produced by the transform
            (starting from ``img_size``) to a tensor with a leading batch
            dimension of 1.
        """
        image = image.convert("RGB")
        image_info = {"img_size": image.size}

        pixels, image_info = self.transform(image, image_info)
        image_info = {
            key: torch.tensor(val).unsqueeze(0) for key, val in image_info.items()
        }

        batch = torch.tensor(pixels).unsqueeze(0)
        batch = batch.float().sub_(self.mean_tensor).div_(self.std_tensor)

        return batch, image_info
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class EfficientDetLayoutModel(BaseLayoutModel):
    """Create an EfficientDet-based Layout Detection Model

    Args:
        config_path (:obj:`str`):
            The path to the configuration file.
        model_path (:obj:`str`, None):
            The path to the saved weights of the model.
            If set, overwrite the weights in the configuration file.
            Defaults to `None`.
        label_map (:obj:`dict`, optional):
            The map from the model prediction (ids) to real
            word labels (strings). If the config is from one of the supported
            datasets, Layout Parser will automatically initialize the label_map.
            Defaults to `None`.
        extra_config (:obj:`dict`, optional):
            Extra configuration passed to the EfficientDet model
            configuration. Currently supported arguments:
                num_classes: specifying the number of classes for the models
                output_confidence_threshold: minimum object prediction confidence to retain
        enforce_cpu (:obj:`bool`, optional):
            When set to `True`, it will enforce using cpu even if it is on a CUDA
            available device. Defaults to `False`.
        device (:obj:`str`, optional):
            Torch device to run on (e.g. ``"cuda:1"``). Only honoured when CUDA
            is available and `enforce_cpu` is `False`; in that case it defaults
            to ``"cuda"``. Otherwise ``"cpu"`` is always used.

    Examples::
        >>> import layoutparser as lp
        >>> model = lp.EfficientDetLayoutModel("lp://PubLayNet/tf_efficientdet_d0/config")
        >>> model.detect(image)

    """

    DEPENDENCIES = ["effdet"]
    DETECTOR_NAME = "efficientdet"
    MODEL_CATALOG = MODEL_CATALOG

    DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD = 0.25

    def __init__(
        self,
        config_path: str,
        model_path: Optional[str] = None,
        label_map: Optional[Dict] = None,
        extra_config: Optional[Dict] = None,
        enforce_cpu: bool = False,
        device: Optional[str] = None,
    ):

        # `enforce_cpu` used to be ignored; honour it as documented. Without
        # CUDA everything runs on the CPU regardless of `device`.
        if enforce_cpu or not is_torch_cuda_available():
            device = "cpu"
        elif device is None:
            device = "cuda"
        self.device = device

        extra_config = extra_config if extra_config is not None else {}

        self._initialize_model(config_path, model_path, label_map, extra_config)

        self.output_confidence_threshold = extra_config.get(
            "output_confidence_threshold", self.DEFAULT_OUTPUT_CONFIDENCE_THRESHOLD
        )

        self.preprocessor = InputTransform(self.config.image_size)

    def _initialize_model(
        self,
        config_path: str,
        model_path: Optional[str],
        label_map: Optional[Dict],
        extra_config: Optional[Dict],
    ):
        """Resolve the config/weights and build the effdet predictor.

        Sets ``self.model`` (moved to ``self.device`` and put in eval mode),
        ``self.config`` and ``self.label_map``.
        """
        config_path, model_path = self.config_parser(config_path, model_path)

        lp_prefix = "lp://"
        if config_path.startswith(lp_prefix):
            # If it's officially supported by layoutparser. Slice the prefix off
            # explicitly: `str.lstrip("lp://")` removes any leading 'l', 'p',
            # ':' or '/' characters rather than the literal prefix.
            # NOTE(review): presumably `config_parser` returns the full format
            # lp://<backend>/<dataset>/<model-arch>/<identifier> — confirm.
            dataset_name, model_name = config_path[len(lp_prefix) :].split("/")[1:3]

            if label_map is None:
                label_map = LABEL_MAP_CATALOG[dataset_name]
            num_classes = len(label_map)
        else:
            assert (
                model_path is not None
            ), "When the specified model is not layoutparser-based, you need to specify the model_path"

            # An empty label_map cannot supply the class count either.
            assert (
                label_map or "num_classes" in extra_config
            ), "When the specified model is not layoutparser-based, you need to specify the label_map or add num_classes in the extra_config"

            model_name = config_path
            num_classes = len(label_map) if label_map else extra_config["num_classes"]

        model_path = PathManager.get_local_path(model_path)  # It might be an https URL

        self.model = create_model(
            model_name,
            num_classes=num_classes,
            bench_task="predict",
            pretrained=True,
            checkpoint_path=model_path,
        )

        self.model.to(self.device)
        self.model.eval()
        self.config = self.model.config
        self.label_map = label_map if label_map is not None else {}

    def detect(self, image: Union["np.ndarray", "Image.Image"]) -> Layout:
        """Detect layout elements in ``image``.

        Args:
            image: A PIL image, or a numpy array assumed to be in cv2's BGR
                channel order (a 2-D array is treated as grayscale).

        Returns:
            :obj:`Layout`: the detected blocks whose score reaches the confidence threshold.
        """
        image = self.image_loader(image)

        model_inputs, image_info = self.preprocessor.preprocess(image)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            model_outputs = self.model(
                model_inputs.to(self.device),
                {key: val.to(self.device) for key, val in image_info.items()},
            )

        layout = self.gather_output(model_outputs)
        return layout

    def gather_output(self, model_outputs: "torch.Tensor") -> Layout:
        """Convert raw predictions into a :obj:`Layout` of :obj:`TextBlock`.

        Each detection row is read as ``x_1, y_1, x_2, y_2, score, class``.
        The input tensor is not modified.
        """
        model_outputs = model_outputs.cpu().detach()
        box_predictions = Layout()

        for index, sample in enumerate(model_outputs):
            for det in sample:

                score = float(det[4])
                pred_cat = int(det[5])
                x_1, y_1, x_2, y_2 = det[0:4].tolist()

                # Stop when below this threshold; detections are assumed to be
                # sorted by descending score.
                if score < self.output_confidence_threshold:
                    break

                box_predictions.append(
                    TextBlock(
                        block=Rectangle(x_1, y_1, x_2, y_2),
                        score=score,
                        id=index,
                        type=self.label_map.get(pred_cat, pred_cat),
                    )
                )

        return box_predictions

    def image_loader(self, image: Union["np.ndarray", "Image.Image"]):
        """Return ``image`` as a PIL image; non-array inputs are returned unchanged."""
        # Convert cv2 Image Input
        if isinstance(image, np.ndarray):
            if image.ndim == 2:
                # Single-channel array: no channel order to fix. Previously the
                # BGR flip reversed columns and the RGB conversion then failed.
                return Image.fromarray(image).convert("RGB")
            # In this case, we assume the image is loaded by cv2
            # and the channel order is BGR
            image = image[..., ::-1]
            image = Image.fromarray(image, mode="RGB")

        return image
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Copyright 2021 The Layout Parser team. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""
|
|
16
|
+
Inside layoutparser, we support the following formats for specifying layout model configs
|
|
17
|
+
or weights:
|
|
18
|
+
|
|
19
|
+
1. URL-based formats:
|
|
20
|
+
- A local path: ~/models/publaynet/path
|
|
21
|
+
- Link to the models: https://web/url/to/models
|
|
22
|
+
|
|
23
|
+
2. LayoutParser Based Model/Config Path Formats:
|
|
24
|
+
- Full format: lp://<backend-name>/<dataset-name>/<model-architecture-name>
|
|
25
|
+
- Short format: lp://<dataset-name>/<model-architecture-name>
|
|
26
|
+
- Brief format: lp://<dataset-name>
|
|
27
|
+
|
|
28
|
+
For each LayoutParser-based format, you could also add a `config` or `weight` identifier
|
|
29
|
+
after them:
|
|
30
|
+
- Full format: lp://<backend-name>/<dataset-name>/<model-architecture-name>/<config, weight>
|
|
31
|
+
- Short format: lp://<dataset-name>/<model-architecture-name>/<config, weight>
|
|
32
|
+
- Brief format: lp://<dataset-name>/<config, weight>
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from dataclasses import dataclass
|
|
36
|
+
|
|
37
|
+
# Scheme prefix that marks a path as a LayoutParser model-zoo reference.
LAYOUT_PARSER_MODEL_PREFIX = "lp://"

# The final component of a model-zoo reference selects which artifact is
# meant: the model configuration or the trained weights.
ALLOWED_LAYOUT_MODEL_IDENTIFIER_NAMES = [
    "config",
    "weight",
]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
@dataclass
class LayoutModelConfig:
    """Parsed components of a LayoutParser model-zoo reference.

    The canonical spelling is
    ``lp://<backend-name>/<dataset-name>/<model-architecture-name>/<identifier>``,
    where ``identifier`` must be one of ``ALLOWED_LAYOUT_MODEL_IDENTIFIER_NAMES``
    (checked with an ``assert`` right after construction).
    """

    backend_name: str
    dataset_name: str
    model_arch: str
    identifier: str

    def __post_init__(self):
        assert self.identifier in ALLOWED_LAYOUT_MODEL_IDENTIFIER_NAMES

    @property
    def full(self):
        """``lp://<backend>/<dataset>/<arch>/<identifier>``."""
        components = (
            self.backend_name,
            self.dataset_name,
            self.model_arch,
            self.identifier,
        )
        return LAYOUT_PARSER_MODEL_PREFIX + "/".join(components)

    @property
    def short(self):
        """``lp://<dataset>/<arch>/<identifier>`` (the backend is omitted)."""
        components = (self.dataset_name, self.model_arch, self.identifier)
        return LAYOUT_PARSER_MODEL_PREFIX + "/".join(components)

    @property
    def brief(self):
        """``lp://<dataset>/<arch>``, with no identifier.

        NOTE(review): the module docstring describes the brief format as
        ``lp://<dataset-name>`` (optionally followed by the identifier), which
        does not match what this returns — confirm which one is intended.
        """
        components = (self.dataset_name, self.model_arch)
        return LAYOUT_PARSER_MODEL_PREFIX + "/".join(components)

    def dual(self):
        """Return a new config for the same model but the *other* artifact.

        The first allowed identifier that differs from ``self.identifier`` is
        used (e.g. ``config`` -> ``weight`` and vice versa).
        """
        other_identifier = next(
            (
                name
                for name in ALLOWED_LAYOUT_MODEL_IDENTIFIER_NAMES
                if name != self.identifier
            ),
            ALLOWED_LAYOUT_MODEL_IDENTIFIER_NAMES[-1],
        )
        return self.__class__(
            backend_name=self.backend_name,
            dataset_name=self.dataset_name,
            model_arch=self.model_arch,
            identifier=other_identifier,
        )
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def is_lp_layout_model_config_any_format(config: str) -> bool:
    """Tell whether ``config`` looks like a LayoutParser model-zoo reference.

    It must start with ``LAYOUT_PARSER_MODEL_PREFIX`` and the remainder must
    consist of one to four "/"-separated components (empty components, such
    as those produced by a trailing slash, are counted as well).
    """
    if config.startswith(LAYOUT_PARSER_MODEL_PREFIX):
        remainder = config[len(LAYOUT_PARSER_MODEL_PREFIX) :]
        # N separators give N + 1 components; at most four are allowed.
        return remainder.count("/") <= 3
    return False
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def add_identifier_for_config(config: str, identifier: str) -> str:
    """Return ``config`` ending in exactly one ``/<identifier>`` component.

    Trailing slashes are dropped and an existing ``/<identifier>`` suffix is
    not duplicated, e.g. ``lp://PubLayNet/tf_efficientdet_d0`` and
    ``lp://PubLayNet/tf_efficientdet_d0/config/`` both become
    ``lp://PubLayNet/tf_efficientdet_d0/config`` for ``identifier="config"``.
    """
    suffix = f"/{identifier}"
    config = config.rstrip("/")
    # Remove the suffix as a whole string. `str.rstrip(suffix)` would strip any
    # trailing characters found in the suffix, mangling names such as
    # "..._365e" (its "e" is a character of "/weight").
    if config.endswith(suffix):
        config = config[: -len(suffix)]
    return config + suffix
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def layout_model_config_parser(
    config, backend_name=None, model_arch=None
) -> LayoutModelConfig:
    """Split a model-zoo reference (already ending in an identifier) into parts.

    Accepted shapes after the ``lp://`` prefix:

    * 4 parts: ``<backend>/<dataset>/<arch>/<identifier>``
    * 3 parts: ``<dataset>/<arch>/<identifier>``, or, when the first part
      equals ``backend_name``, ``<backend>/<dataset>/<identifier>`` (the
      architecture then comes from ``model_arch``)
    * 2 parts: ``<dataset>/<identifier>`` (backend and architecture come from
      the arguments)

    ``backend_name``/``model_arch`` are required (checked via ``assert``)
    whenever the reference does not spell them out.

    Raises:
        AssertionError: missing identifier or missing fallback arguments.
        ValueError: any other number of parts, or ``<backend>/<identifier>``.
    """
    assert config.split("/")[-1] in ALLOWED_LAYOUT_MODEL_IDENTIFIER_NAMES, (
        f"The input config {config} does not contain identifier information."
        f"Consider run `config = add_identifier_for_config(config, identifier)` first."
    )

    parts = config[len(LAYOUT_PARSER_MODEL_PREFIX) :].split("/")
    n_parts = len(parts)

    if n_parts == 4:  # Full format
        backend_name, dataset_name, model_arch, identifier = parts
    elif n_parts == 3:  # Short format
        assert backend_name is not None
        if parts[0] != backend_name:
            # lp://<dataset-name>/<model-arch>/<identifier>
            dataset_name, model_arch, identifier = parts
        else:
            # lp://<backend-name>/<dataset-name>/<identifier>
            assert model_arch is not None
            _, dataset_name, identifier = parts
    elif n_parts == 2:  # Brief format
        assert backend_name is not None
        assert model_arch is not None
        if parts[0] == backend_name:
            # lp://<backend-name>/<identifier> names no dataset
            raise ValueError(f"Invalid LP Model Config {config}")
        # lp://<dataset-name>/<identifier>
        dataset_name, identifier = parts
    else:
        raise ValueError(f"Invalid LP Model Config {config}")

    return LayoutModelConfig(
        backend_name=backend_name,
        dataset_name=dataset_name,
        model_arch=model_arch,
        identifier=identifier,
    )
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Copyright 2021 The Layout Parser team and Paddle Detection model
|
|
2
|
+
# contributors. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
from . import catalog as _UNUSED
|
|
17
|
+
from .layoutmodel import PaddleDetectionLayoutModel
|
|
@@ -0,0 +1,214 @@
|
|
|
1
|
+
# Copyright 2021 The Layout Parser team and Paddle Detection model
|
|
2
|
+
# contributors. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
|
|
16
|
+
import os
|
|
17
|
+
import logging
|
|
18
|
+
from typing import Any, Optional
|
|
19
|
+
from urllib.parse import urlparse
|
|
20
|
+
import tarfile
|
|
21
|
+
import uuid
|
|
22
|
+
|
|
23
|
+
from iopath.common.file_io import PathHandler
|
|
24
|
+
from iopath.common.file_io import HTTPURLHandler
|
|
25
|
+
from iopath.common.file_io import get_cache_dir, file_lock
|
|
26
|
+
from iopath.common.download import download
|
|
27
|
+
|
|
28
|
+
from ..base_catalog import PathManager
|
|
29
|
+
|
|
30
|
+
# Download URL of every PaddleDetection layout model, keyed first by training
# dataset and then by model architecture name.
MODEL_CATALOG = {
    "PubLayNet": dict(
        ppyolov2_r50vd_dcn_365e="https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar",
    ),
    "TableBank": dict(
        ppyolov2_r50vd_dcn_365e="https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar",
        # A LaTeX variant also exists at
        # https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar
        # TODO: Train a single tablebank model for paddlepaddle
    ),
}
|
|
40
|
+
|
|
41
|
+
# Predicted class id -> human-readable label, per training dataset. Labels are
# listed in id order, so enumerate() assigns the ids 0, 1, 2, ...
# fmt: off
LABEL_MAP_CATALOG = {
    "PubLayNet": dict(enumerate(["Text", "Title", "List", "Table", "Figure"])),
    "TableBank": dict(enumerate(["Table"])),
}
# fmt: on
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Paddle models are packaged as tar files; each model's tar file is expected
# to contain these inference files (weights, weight metadata, model graph).
_TAR_FILE_NAME_LIST = [
    "inference." + suffix for suffix in ("pdiparams", "pdiparams.info", "pdmodel")
]
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def _get_untar_directory(tar_file: str) -> str:
    """Return the folder a model tarball is extracted into.

    This is the tarball's own path with its last extension removed, e.g.
    ``/cache/model.tar`` -> ``/cache/model``.
    """
    parent, name = os.path.split(tar_file)
    stem, _extension = os.path.splitext(name)
    return os.path.join(parent, stem)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def _untar_model_weights(model_tar):
    """Extract the required inference files from the tarball ``model_tar``.

    Files whose base name is listed in ``_TAR_FILE_NAME_LIST`` are written
    flat into ``_get_untar_directory(model_tar)``. Extraction is skipped when
    that folder already holds every required file.

    Returns:
        str: the folder the files were extracted into.
    """
    model_dir = _get_untar_directory(model_tar)

    # Use the same completeness criterion as the download cache (all files in
    # _TAR_FILE_NAME_LIST, including the ".info" one). Checking only two of
    # them let a folder pass here but fail `is_cached_folder_exists_and_valid`,
    # so it was never repaired.
    if is_cached_folder_exists_and_valid(model_tar):
        return model_dir

    # the path to save the decompressed file
    os.makedirs(model_dir, exist_ok=True)
    with tarfile.open(model_tar, "r") as tarobj:
        for member in tarobj.getmembers():
            # Only regular files can be read back through extractfile().
            if not member.isfile():
                continue
            # Match the exact base name. A substring test relied on list order
            # to tell "inference.pdiparams" from "inference.pdiparams.info" and
            # also accepted unrelated names merely containing them.
            filename = os.path.basename(member.name)
            if filename not in _TAR_FILE_NAME_LIST:
                continue
            # The destination is a fixed, known name inside model_dir, so the
            # member's own path never decides where data is written.
            with tarobj.extractfile(member) as src, open(
                os.path.join(model_dir, filename), "wb"
            ) as dst:
                dst.write(src.read())
    return model_dir
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def is_cached_folder_exists_and_valid(cached):
    """Tell whether the extraction folder for the tarball path ``cached`` is usable.

    Usable means the folder from ``_get_untar_directory(cached)`` exists and
    contains every file listed in ``_TAR_FILE_NAME_LIST``.
    """
    folder = _get_untar_directory(cached)
    return os.path.exists(folder) and all(
        os.path.exists(os.path.join(folder, required)) for required in _TAR_FILE_NAME_LIST
    )
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
class PaddleModelURLHandler(HTTPURLHandler):
    """
    Supports download and file check for Baidu Cloud links
    """

    MAX_FILENAME_LEN = 250

    def _get_supported_prefixes(self):
        return ["https://paddle-model-ecology.bj.bcebos.com"]

    def _isfile(self, path):
        return path in self.cache_map

    def _get_local_path(
        self,
        path: str,
        force: bool = False,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """
        As paddle model stores all files in tar files, we need to extract them
        and get the newly extracted folder path. This function rewrites the base
        function to support the following situations:

        1. If the tar file is not downloaded, it will download the tar file,
            extract it to the target folder, delete the downloaded tar file,
            and return the folder path.
        2. If the extracted target folder is present, and all the necessary model
            files are present (specified in _TAR_FILE_NAME_LIST), it will
            return the folder path. This is re-checked after acquiring the
            file lock, so a waiter reuses a folder another holder of the lock
            just produced instead of downloading the (already deleted) tar again.
        3. If the tar file is downloaded, but the extracted target folder is not
            present (or it doesn't contain the necessary files in _TAR_FILE_NAME_LIST),
            it will extract the tar file to the target folder, delete the tar file,
            and return the folder path.
        4. If the URL does not end with ".tar", the downloaded file itself is
            returned (previously this raised UnboundLocalError).

        """
        self._check_kwargs(kwargs)
        if (
            force
            or path not in self.cache_map
            or not os.path.exists(self.cache_map[path])
        ):
            logger = logging.getLogger(__name__)
            parsed_url = urlparse(path)
            dirname = os.path.join(
                get_cache_dir(cache_dir), os.path.dirname(parsed_url.path.lstrip("/"))
            )
            filename = path.split("/")[-1]
            if len(filename) > self.MAX_FILENAME_LEN:
                filename = filename[:100] + "_" + uuid.uuid4().hex

            cached = os.path.join(dirname, filename)

            if is_cached_folder_exists_and_valid(cached):
                # When the cached folder exists and valid, we don't need to redownload
                # the tar file.
                self.cache_map[path] = _get_untar_directory(cached)

            else:
                with file_lock(cached):
                    if is_cached_folder_exists_and_valid(cached):
                        # Another holder of the lock extracted the model (and
                        # removed the tar) while we were waiting.
                        local_path = _get_untar_directory(cached)
                    else:
                        if not os.path.isfile(cached):
                            logger.info("Downloading {} ...".format(path))
                            cached = download(path, dirname, filename=filename)

                        if path.endswith(".tar"):
                            local_path = _untar_model_weights(cached)
                            try:
                                os.remove(cached)  # remove the redundant tar file
                                # TODO: remove the .lock file .
                            except OSError:
                                logger.warning(
                                    f"Not able to remove the cached tar file {cached}"
                                )
                        else:
                            # Not an archive: the downloaded file is the result.
                            local_path = cached

                logger.info("URL {} cached in {}".format(path, local_path))
                self.cache_map[path] = local_path

        return self.cache_map[path]
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
class LayoutParserPaddleModelHandler(PathHandler):
    """
    Resolve ``lp://paddledetection/<dataset>/<model-arch>/weight`` references
    from the LayoutParser model zoo to local paths via ``MODEL_CATALOG``.
    """

    PREFIX = "lp://paddledetection/"

    def _get_supported_prefixes(self):
        return [self.PREFIX]

    def _get_local_path(self, path, **kwargs):
        # Layout after the prefix: <dataset>/<model-arch ...>/<data-type>
        dataset_name, *arch_parts, data_type = path[len(self.PREFIX) :].split("/")

        if data_type != "weight":
            raise ValueError(f"Unknown data_type {data_type}")

        model_url = MODEL_CATALOG[dataset_name]["/".join(arch_parts)]
        return PathManager.get_local_path(model_url, **kwargs)

    def _open(self, path, mode="r", **kwargs):
        # NOTE(review): catalog URLs end in ".tar", which the URL handler
        # resolves to an extracted folder, so opening the result as a file
        # will likely fail — confirm whether this is ever used.
        local_path = self._get_local_path(path)
        return PathManager.open(local_path, mode, **kwargs)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
# Register both handlers with the shared PathManager at import time, in order:
# first the raw Baidu Cloud URL handler, then the lp://paddledetection/ resolver.
for _handler_cls in (PaddleModelURLHandler, LayoutParserPaddleModelHandler):
    PathManager.register_handler(_handler_cls())
del _handler_cls
|