docling-ibm-models 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
- docling_ibm_models/tableformer/__init__.py +0 -0
- docling_ibm_models/tableformer/common.py +200 -0
- docling_ibm_models/tableformer/data_management/__init__.py +0 -0
- docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
- docling_ibm_models/tableformer/data_management/functional.py +574 -0
- docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
- docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
- docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
- docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
- docling_ibm_models/tableformer/data_management/transforms.py +396 -0
- docling_ibm_models/tableformer/models/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/base_model.py +279 -0
- docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
- docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
- docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
- docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
- docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
- docling_ibm_models/tableformer/otsl.py +541 -0
- docling_ibm_models/tableformer/settings.py +90 -0
- docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
- docling_ibm_models/tableformer/test_prepare_image.py +99 -0
- docling_ibm_models/tableformer/utils/__init__.py +0 -0
- docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
- docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
- docling_ibm_models/tableformer/utils/utils.py +376 -0
- docling_ibm_models/tableformer/utils/variance.py +175 -0
- docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
- docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
- docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
- docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,171 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import os
from collections.abc import Iterable
from typing import Optional, Union

import numpy as np
import onnxruntime as ort
from PIL import Image
|
12
|
+
|
13
|
+
# File name of the model that LayoutPredictor loads from its artifact directory
MODEL_CHECKPOINT_FN: str = "model.pt"
# Intra-op thread count used when neither `num_threads` nor OMP_NUM_THREADS is given
DEFAULT_NUM_THREADS: int = 4
|
15
|
+
|
16
|
+
|
17
|
+
# Classes:
# Maps the integer label emitted by the layout model to its human-readable name.
# The position in the list below is the label index.
CLASSES_MAP = dict(
    enumerate(
        [
            "background",
            "Caption",
            "Footnote",
            "Formula",
            "List-item",
            "Page-footer",
            "Page-header",
            "Picture",
            "Section-header",
            "Table",
            "Text",
            "Title",
            "Document Index",
            "Code",
            "Checkbox-Selected",
            "Checkbox-Unselected",
            "Form",
            "Key-Value Region",
        ]
    )
)
|
38
|
+
|
39
|
+
|
40
|
+
class LayoutPredictor:
    r"""
    Document layout prediction using ONNX
    """

    def __init__(
        self,
        artifact_path: str,
        num_threads: Optional[int] = None,
        use_cpu_only: bool = False,
    ):
        r"""
        Provide the artifact path that contains the LayoutModel ONNX file

        The number of threads is decided, in the following order, by:
        1. The init method parameter `num_threads`, if it is set.
        2. The envvar "OMP_NUM_THREADS", if it is set to a non-empty value.
        3. The default value DEFAULT_NUM_THREADS.

        The execution providers are decided, in the following order:
        1. If the init method parameter `use_cpu_only` is True or the envvar
           "USE_CPU_ONLY" is set (to any value), only "CPUExecutionProvider" is used.
        2. Otherwise, if "CUDAExecutionProvider" is available, use:
           ["CUDAExecutionProvider", "CPUExecutionProvider"]
        3. Otherwise, use ["CPUExecutionProvider"].

        Parameters
        ----------
        artifact_path: Directory that contains the model ONNX file (named MODEL_CHECKPOINT_FN).
        num_threads: (Optional) Number of intra-op threads used to run the inference.
        use_cpu_only: (Optional) If True, it forces CPU as the execution provider.

        Raises
        ------
        FileNotFoundError when the model's ONNX file is missing
        ValueError when "OMP_NUM_THREADS" is consulted but is not an integer
        """
        # Set basic params
        self._threshold = 0.6  # Detections with a score <= threshold are dropped
        self._image_size = 640  # Input images are resized to this square size
        # Fed to the model as "orig_target_sizes": boxes come back in the resized
        # (image_size x image_size) space and predict() rescales them afterwards.
        self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)

        # Get env vars
        self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
        if num_threads is None:
            # An empty OMP_NUM_THREADS is treated as unset instead of crashing in int("")
            env_threads = os.environ.get("OMP_NUM_THREADS", "").strip()
            num_threads = int(env_threads) if env_threads else DEFAULT_NUM_THREADS
        self._num_threads = num_threads

        # Decide the execution providers
        if (
            not self._use_cpu_only
            and "CUDAExecutionProvider" in ort.get_available_providers()
        ):
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        self._providers = providers

        # Model ONNX file
        self._onnx_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
        if not os.path.isfile(self._onnx_fn):
            raise FileNotFoundError("Missing ONNX file: {}".format(self._onnx_fn))

        # ONNX options
        self._options = ort.SessionOptions()
        self._options.intra_op_num_threads = self._num_threads
        self.sess = ort.InferenceSession(
            self._onnx_fn,
            sess_options=self._options,
            providers=self._providers,
        )

    def info(self) -> dict:
        r"""
        Get information about the configuration of LayoutPredictor

        Returns
        -------
        dict with the keys: "onnx_file", "intra_op_num_threads", "use_cpu_only",
        "providers", "image_size", "threshold"
        """
        info = {
            "onnx_file": self._onnx_fn,
            "intra_op_num_threads": self._num_threads,
            "use_cpu_only": self._use_cpu_only,
            # Return a copy so callers cannot mutate the predictor's internal state
            "providers": list(self._providers),
            "image_size": self._image_size,
            "threshold": self._threshold,
        }
        return info

    def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
        r"""
        Predict bounding boxes for a given image.
        The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
        [left, top, right, bottom]

        Parameter
        ---------
        orig_img: Image to be predicted as a PIL Image object or numpy array.

        Yield
        -----
        Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b".
        Coordinates are scaled to the size of `orig_img`; only boxes whose score is
        above the threshold are yielded.

        Raises
        ------
        TypeError when the input image is not supported. Because this method is a
        generator, the error surfaces when iteration starts, not at call time.
        """
        # Convert image format
        if isinstance(orig_img, Image.Image):
            page_img = orig_img.convert("RGB")
        elif isinstance(orig_img, np.ndarray):
            page_img = Image.fromarray(orig_img).convert("RGB")
        else:
            raise TypeError("Not supported input image format")

        # Remember the original size to map the boxes back after inference
        w, h = page_img.size
        page_img = page_img.resize((self._image_size, self._image_size))
        page_data = np.array(page_img, dtype=np.uint8) / np.float32(255.0)
        # HWC -> CHW, then add the batch dimension
        page_data = np.expand_dims(np.transpose(page_data, axes=[2, 0, 1]), axis=0)

        # Predict
        labels, boxes, scores = self.sess.run(
            output_names=None,
            input_feed={
                "images": page_data,
                "orig_target_sizes": self._size,
            },
        )

        # Yield output (first and only batch element)
        for label, box, score in zip(labels[0], boxes[0], scores[0]):
            if score > self._threshold:
                yield {
                    "l": box[0] / self._image_size * w,
                    "t": box[1] / self._image_size * h,
                    "r": box[2] / self._image_size * w,
                    "b": box[3] / self._image_size * h,
                    "label": CLASSES_MAP[label],
                    "confidence": score,
                }
|
File without changes
|
@@ -0,0 +1,200 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import argparse
|
6
|
+
import json
|
7
|
+
import logging
|
8
|
+
import os
|
9
|
+
|
10
|
+
import torch
|
11
|
+
|
12
|
+
import docling_ibm_models.tableformer.settings as s
|
13
|
+
from docling_ibm_models.tableformer.models.common.base_model import BaseModel
|
14
|
+
|
15
|
+
LOG_LEVEL = logging.DEBUG  # verbosity of the module-level logger below
logger = s.get_custom_logger(
    "common",  # logger name
    LOG_LEVEL,
)
|
17
|
+
|
18
|
+
|
19
|
+
def validate_config(config):
    r"""
    Validate the provided configuration.

    The checks are only performed when both the "model" and the "preparation"
    sections are present; otherwise the configuration is accepted as-is.

    Parameters
    ----------
    config : dictionary
        Configuration for the tablemodel

    Returns
    -------
    bool : True on success

    Raises
    ------
    ValueError
        If 'preparation.max_tag_len' is missing, or if 'model.seq_len' is present
        but is not positive or exceeds 'preparation.max_tag_len' + 2.
    """
    if "model" not in config or "preparation" not in config:
        return True

    model_cfg = config["model"]
    prep_cfg = config["preparation"]

    # Explicit raises instead of `assert`: asserts are stripped under `python -O`,
    # which would silently disable the validation.
    if "max_tag_len" not in prep_cfg:
        raise ValueError(
            "Config error: 'preparation.max_tag_len' parameter is missing"
        )

    if "seq_len" in model_cfg:
        seq_len = model_cfg["seq_len"]
        # Negated comparisons keep the original `assert` semantics (e.g. NaN fails)
        if not seq_len > 0:
            raise ValueError("Config error: 'model.seq_len' should be positive")
        if not seq_len <= prep_cfg["max_tag_len"] + 2:
            raise ValueError(
                "Config error: 'model.seq_len' should be up to 'preparation.max_tag_len' + 2"
            )

    return True
|
49
|
+
|
50
|
+
|
51
|
+
def parse_arguments():
    r"""
    Parse the command line arguments and load the configuration file.

    Expects a mandatory `-c/--config` argument pointing to a JSON config file.

    Returns
    -------
    dictionary : the configuration returned by `read_config`

    A missing config file is reported through `ArgumentParser.error`, which prints
    the usage and exits with status 2. Any exception raised by `read_config`
    (e.g. while parsing or validating the file) is propagated.
    """
    parser = argparse.ArgumentParser(description="Train the TableModel")
    parser.add_argument(
        "-c", "--config", required=True, help="configuration file (JSON)"
    )
    args = parser.parse_args()
    config_filename = args.config

    # A regular CLI error instead of `assert`, which is stripped under `python -O`
    if not os.path.isfile(config_filename):
        parser.error("FAILURE: Config file not found.")
    return read_config(config_filename)
|
65
|
+
|
66
|
+
|
67
|
+
def read_config(config_filename):
    r"""
    Load a JSON configuration file and validate it with `validate_config`.

    Parameters
    ----------
    config_filename : string
        Path to the JSON configuration file

    Returns
    -------
    dictionary : the parsed configuration

    Exceptions raised by `open`, `json.load` or `validate_config` are propagated
    (e.g. FileNotFoundError, json.JSONDecodeError).
    """
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform locale encoding
    with open(config_filename, "r", encoding="utf-8") as fd:
        config = json.load(fd)

    # Validate the config file
    validate_config(config)

    return config
|
75
|
+
|
76
|
+
|
77
|
+
def safe_get_parameter(input_dict, index_path, default=None, required=False):
    r"""
    Safe get parameter from a nested dictionary.

    Provide a nested dictionary (dictionary of dictionaries) and a list of indices:
    - If the whole index path exists the value pointed by it is returned
    - Otherwise the default value is returned.

    The path is considered broken when a key is missing or when an intermediate
    value is not a mapping (e.g. a number or a string), so the lookup never fails
    with a TypeError. An empty index_path points at input_dict itself.

    Input:
        input_dict: Data structure of nested dictionaries.
        index_path: List with the indices path to follow inside the input_dict.
        default: Default value to return if the indices path is broken.
        required: If true a ValueError exception will be raised in case the parameter does not exist
    Output:
        The value pointed by the index path or "default".
        If input_dict or index_path is None, "default" is returned even when required is True.
    """
    from collections.abc import Mapping

    if input_dict is None or index_path is None:
        return default

    node = input_dict
    for key in index_path:
        # Only descend into mappings: `in` on a str would do a substring test and
        # `in` on an int would raise TypeError.
        if not isinstance(node, Mapping) or key not in node:
            if required:
                raise ValueError("Missing parameter: {}".format(key))
            return default
        node = node[key]

    return node
|
111
|
+
|
112
|
+
|
113
|
+
def get_prepared_data_filename(prepared_data_part, dataset_name):
    r"""
    Resolve the filename of one part of the prepared data.

    The filename template is looked up in `s.PREPARED_DATA_PARTS`. When it contains
    the "<POSTFIX>" placeholder, every occurrence is replaced with `dataset_name`;
    otherwise the template is returned unchanged.

    Parameters
    ----------
    prepared_data_part : string
        Key of the prepared data part inside `s.PREPARED_DATA_PARTS`
    dataset_name : string
        Name of the dataset

    Returns
    -------
    string
        The resolved filename for the prepared file
    """
    placeholder = "<POSTFIX>"
    filename = s.PREPARED_DATA_PARTS[prepared_data_part]
    # Guard the replace so a template without the placeholder never touches dataset_name
    return filename.replace(placeholder, dataset_name) if placeholder in filename else filename
|
133
|
+
|
134
|
+
|
135
|
+
def create_dataset_and_model(config, purpose, fixed_padding=False):
    r"""
    Create the dataset(s) and the model described by the configuration

    Parameters
    ----------
    config : Dictionary
        The configuration of the model
    purpose : string
        One of "train", "eval", "predict"
    fixed_padding : bool
        Parameter passed to the constructor of the TFDataset

    Returns
    -------
    If purpose == s.PREDICT_PURPOSE: (device, dataset, model)
    Otherwise: (device, dataset, dataset_val, model)

    In case a model cannot be initialized, a tuple of Nones with the same length
    as above is returned.

    device : selected device, "cuda:0" or "cpu"
    dataset : Instance of the TFDataset for `purpose`
    dataset_val : Validation TFDataset, created only when purpose == "train" and
        config["train"]["validation"] is truthy; None otherwise
    model : Instance of the model
    """
    from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset

    model_type = config["model"]["type"]
    is_predict = purpose == s.PREDICT_PURPOSE

    # Any non-empty value of the USE_CPU_ONLY env var forces the CPU.
    # Otherwise run on the first GPU if one is available.
    use_cpu_only = os.environ.get("USE_CPU_ONLY", False)
    if not use_cpu_only and torch.cuda.device_count() > 0:
        device = "cuda:0"
    else:
        device = "cpu"

    # Create the datasets
    dataset = TFDataset(config, purpose, fixed_padding=fixed_padding)
    dataset.set_device(device)
    dataset_val = None
    # Check the purpose first so that eval/predict configs without a "train"
    # section do not fail with a KeyError.
    if purpose == "train" and config["train"]["validation"]:
        dataset_val = TFDataset(config, "val", fixed_padding=fixed_padding)
        dataset_val.set_device(device)

    if model_type == "TableModel04_rs":
        # Imported only for its side effect: defining the class so that it can be
        # found below (presumably it subclasses BaseModel directly).
        from docling_ibm_models.tableformer.models.table04_rs.tablemodel04_rs import (  # noqa: F401
            TableModel04_rs,
        )

    # Find the model class and create an instance of it
    model = None
    for candidate in BaseModel.__subclasses__():
        if candidate.__name__ == model_type:
            init_data = dataset.get_init_data()
            model = candidate(config, init_data, purpose, device)
            break

    if model is None:
        logger.warning("Not found model: " + str(model_type))
        # Match the arity of the successful return so callers can always unpack
        return (None, None, None) if is_predict else (None, None, None, None)

    logger.info("Found model: " + str(model_type))

    if is_predict:
        return device, dataset, model
    return device, dataset, dataset_val, model
|
File without changes
|