docling-ibm-models 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32) hide show
  1. docling_ibm_models-0.1.0/LICENSE +21 -0
  2. docling_ibm_models-0.1.0/PKG-INFO +172 -0
  3. docling_ibm_models-0.1.0/README.md +127 -0
  4. docling_ibm_models-0.1.0/docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
  5. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/__init__.py +0 -0
  6. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/common.py +200 -0
  7. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/__init__.py +0 -0
  8. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
  9. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/functional.py +574 -0
  10. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
  11. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
  12. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
  13. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
  14. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/transforms.py +396 -0
  15. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/__init__.py +0 -0
  16. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/common/__init__.py +0 -0
  17. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/common/base_model.py +279 -0
  18. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
  19. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
  20. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
  21. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
  22. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
  23. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/otsl.py +541 -0
  24. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/settings.py +90 -0
  25. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
  26. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/test_prepare_image.py +99 -0
  27. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/utils/__init__.py +0 -0
  28. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
  29. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
  30. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/utils/utils.py +376 -0
  31. docling_ibm_models-0.1.0/docling_ibm_models/tableformer/utils/variance.py +175 -0
  32. docling_ibm_models-0.1.0/pyproject.toml +82 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 IBM Corp.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,172 @@
1
+ Metadata-Version: 2.1
2
+ Name: docling-ibm-models
3
+ Version: 0.1.0
4
+ Summary: This package contains the AI models used by the Docling PDF conversion package
5
+ License: MIT
6
+ Keywords: docling,convert,document,pdf,layout model,segmentation,table structure,table former
7
+ Author: Nikos Livathinos
8
+ Author-email: nli@zurich.ibm.com
9
+ Requires-Python: >=3.11,<4.0
10
+ Classifier: Development Status :: 5 - Production/Stable
11
+ Classifier: Intended Audience :: Developers
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Operating System :: MacOS :: MacOS X
15
+ Classifier: Operating System :: POSIX :: Linux
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Requires-Dist: Distance (>=0.1.3,<0.2.0)
21
+ Requires-Dist: Pillow (>=10.0.0,<11.0.0)
22
+ Requires-Dist: apted (>=1.0.3,<2.0.0)
23
+ Requires-Dist: jsonlines (>=3.1.0,<4.0.0)
24
+ Requires-Dist: lxml (>=4.9.1,<5.0.0)
25
+ Requires-Dist: mean_average_precision (>=2021.4.26.0,<2022.0.0.0)
26
+ Requires-Dist: numpy (>=1.24.4,<2.0.0)
27
+ Requires-Dist: onnxruntime (>=1.16.2,<2.0.0)
28
+ Requires-Dist: opencv-python (>=4.9.0.80,<5.0.0.0) ; sys_platform != "linux"
29
+ Requires-Dist: opencv-python-headless (>=4.9.0.80,<5.0.0.0) ; sys_platform == "linux"
30
+ Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" and platform_machine == "x86_64" and sys_platform == "linux"
31
+ Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2%2Bcpu-cp312-cp312-linux_x86_64.whl ; python_version == "3.12" and platform_machine == "x86_64" and sys_platform == "linux"
32
+ Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2-cp311-none-macosx_10_9_x86_64.whl ; python_version == "3.11" and platform_machine == "x86_64" and sys_platform == "darwin"
33
+ Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2-cp311-none-macosx_11_0_arm64.whl ; python_version == "3.11" and platform_machine == "arm64" and sys_platform == "darwin"
34
+ Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2-cp312-none-macosx_10_9_x86_64.whl ; python_version == "3.12" and platform_machine == "x86_64" and sys_platform == "darwin"
35
+ Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2-cp312-none-macosx_11_0_arm64.whl ; python_version == "3.12" and platform_machine == "arm64" and sys_platform == "darwin"
36
+ Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" and platform_machine == "x86_64" and sys_platform == "linux"
37
+ Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2%2Bcpu-cp312-cp312-linux_x86_64.whl ; python_version == "3.12" and platform_machine == "x86_64" and sys_platform == "linux"
38
+ Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2-cp311-cp311-macosx_10_13_x86_64.whl ; python_version == "3.11" and platform_machine == "x86_64" and sys_platform == "darwin"
39
+ Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2-cp311-cp311-macosx_11_0_arm64.whl ; python_version == "3.11" and platform_machine == "arm64" and sys_platform == "darwin"
40
+ Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2-cp312-cp312-macosx_10_13_x86_64.whl ; python_version == "3.12" and platform_machine == "x86_64" and sys_platform == "darwin"
41
+ Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2-cp312-cp312-macosx_11_0_arm64.whl ; python_version == "3.12" and platform_machine == "arm64" and sys_platform == "darwin"
42
+ Requires-Dist: tqdm (>=4.64.0,<5.0.0)
43
+ Description-Content-Type: text/markdown
44
+
45
+ # Docling-models
46
+
47
+ AI modules to support the Docling PDF document conversion project.
48
+
49
+ - TableFormer is an AI module that recognizes the structure of a table and the bounding boxes of the table content.
50
+ - Layout model is an AI model that provides, among other things, the ability to detect tables on the page. This package contains inference code for the Layout model.
51
+
52
+
53
+ ## Installation Instructions
54
+
55
+ ### MacOS / Linux
56
+
57
+ To install `poetry` locally, use either `pip` or `homebrew`.
58
+
59
+ To install `poetry` on a docker container, do the following:
60
+ ```
61
+ ENV POETRY_NO_INTERACTION=1 \
62
+ POETRY_VIRTUALENVS_CREATE=false
63
+
64
+ # Install poetry
65
+ RUN curl -sSL 'https://install.python-poetry.org' > install-poetry.py \
66
+ && python install-poetry.py \
67
+ && poetry --version \
68
+ && rm install-poetry.py
69
+ ```
70
+
71
+ To install and run the package, simply set up a poetry environment
72
+
73
+ ```
74
+ poetry env use $(which python3.11)
75
+ poetry shell
76
+ ```
77
+
78
+ and install all the dependencies,
79
+
80
+ ```
81
+ poetry install # this will only install the deps from the poetry.lock
82
+
83
+ poetry install --no-dev # this will skip installing dev dependencies
84
+ ```
85
+
86
+ To update or add new dependencies from `pyproject.toml`, rebuild `poetry.lock`
87
+ ```
88
+ poetry update
89
+ ```
90
+
91
+
92
+ ## Pipeline Overview
93
+ ![Architecture](docs/tablemodel_overview_color.png)
94
+
95
+ ## Datasets
96
+ Below we list the datasets used, with their description, source, and ***"TableFormer Format"***. The TableFormer Format is our processed version of the original format, designed to work with the dataloader out of the box and to augment the dataset where necessary with missing ground truth (bounding boxes for empty cells).
97
+
98
+
99
+ | Name | Description | URL |
100
+ | ------------- |:-------------:|----|
101
+ | PubTabNet | PubTabNet contains heterogeneous tables in both image and HTML format, 516k+ tables in the PubMed Central Open Access Subset | [PubTabNet](https://developer.ibm.com/exchanges/data/all/pubtabnet/) |
102
+ | FinTabNet| A dataset for Financial Report Tables with corresponding ground truth location and structure. 112k+ tables included.| [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/) |
103
+ | TableBank| TableBank is a new image-based table detection and recognition dataset built with novel weak supervision from Word and Latex documents on the internet, contains 417K high-quality labeled tables. | [TableBank](https://github.com/doc-analysis/TableBank) |
104
+
105
+ ## Models
106
+
107
+ ### TableModel04:
108
+ ![TableModel04](docs/tbm04.png)
109
+ **TableModel04rs (OTSL)** is our SOTA method that uses transformers to predict table structure and bounding boxes.
110
+
111
+
112
+ ## Configuration file
113
+
114
+ Example configuration can be seen inside test `tests/test_tf_predictor.py`
115
+ These are the main sections of the configuration file:
116
+
117
+ - `dataset`: The directory for prepared data and the parameters used during the data loading.
118
+ - `model`: The type, name and hyperparameters of the model. Also the directory to save/load the
119
+ trained checkpoint files.
120
+ - `train`: Parameters for the training of the model.
121
+ - `predict`: Parameters for the evaluation of the model.
122
+ - `dataset_wordmap`: Very important part that contains token maps.
123
+
124
+
125
+ ## Model weights
126
+
127
+ You can download the model weights and config files from the links:
128
+
129
+ - [TableFormer Checkpoint](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/tableformer)
130
+ - [beehive_v0.0.5](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/layout/beehive_v0.0.5)
131
+
132
+ Place the downloaded files into `tests/test_data/model_artifacts/` directory.
133
+
134
+
135
+ ## Inference Tests
136
+
137
+ This contains unit tests for Docling models.
138
+
139
+ First download the model weights (see above), then run:
140
+ ```
141
+ ./devtools/check_code.sh
142
+ ```
143
+
144
+ This will also generate prediction and matching visualizations that can be found here:
145
+ `tests/test_data/viz/`
146
+
147
+ Visualization outlines:
148
+ - `Light Pink`: border of recognized table
149
+ - `Grey`: OCR cells
150
+ - `Green`: prediction bboxes
151
+ - `Red`: OCR cells matched with prediction
152
+ - `Blue`: Post processed, match
153
+ - `Bold Blue`: column header
154
+ - `Bold Magenta`: row header
155
+ - `Bold Brown`: section row (if the table has one)
156
+
157
+
158
+ ## Demo
159
+
160
+ A demo application applies the `LayoutPredictor` to a directory `<input_dir>` that contains
161
+ `png` images and visualize the predictions inside another directory `<viz_dir>`.
162
+
163
+ First download the model weights (see above), then run:
164
+ ```
165
+ python -m demo.demo_layout_predictor -i <input_dir> -v <viz_dir>
166
+ ```
167
+
168
+ e.g.
169
+ ```
170
+ python -m demo.demo_layout_predictor -i tests/test_data/samples -v viz/
171
+ ```
172
+
@@ -0,0 +1,127 @@
1
+ # Docling-models
2
+
3
+ AI modules to support the Docling PDF document conversion project.
4
+
5
+ - TableFormer is an AI module that recognizes the structure of a table and the bounding boxes of the table content.
6
+ - Layout model is an AI model that provides, among other things, the ability to detect tables on the page. This package contains inference code for the Layout model.
7
+
8
+
9
+ ## Installation Instructions
10
+
11
+ ### MacOS / Linux
12
+
13
+ To install `poetry` locally, use either `pip` or `homebrew`.
14
+
15
+ To install `poetry` on a docker container, do the following:
16
+ ```
17
+ ENV POETRY_NO_INTERACTION=1 \
18
+ POETRY_VIRTUALENVS_CREATE=false
19
+
20
+ # Install poetry
21
+ RUN curl -sSL 'https://install.python-poetry.org' > install-poetry.py \
22
+ && python install-poetry.py \
23
+ && poetry --version \
24
+ && rm install-poetry.py
25
+ ```
26
+
27
+ To install and run the package, simply set up a poetry environment
28
+
29
+ ```
30
+ poetry env use $(which python3.11)
31
+ poetry shell
32
+ ```
33
+
34
+ and install all the dependencies,
35
+
36
+ ```
37
+ poetry install # this will only install the deps from the poetry.lock
38
+
39
+ poetry install --no-dev # this will skip installing dev dependencies
40
+ ```
41
+
42
+ To update or add new dependencies from `pyproject.toml`, rebuild `poetry.lock`
43
+ ```
44
+ poetry update
45
+ ```
46
+
47
+
48
+ ## Pipeline Overview
49
+ ![Architecture](docs/tablemodel_overview_color.png)
50
+
51
+ ## Datasets
52
+ Below we list the datasets used, with their description, source, and ***"TableFormer Format"***. The TableFormer Format is our processed version of the original format, designed to work with the dataloader out of the box and to augment the dataset where necessary with missing ground truth (bounding boxes for empty cells).
53
+
54
+
55
+ | Name | Description | URL |
56
+ | ------------- |:-------------:|----|
57
+ | PubTabNet | PubTabNet contains heterogeneous tables in both image and HTML format, 516k+ tables in the PubMed Central Open Access Subset | [PubTabNet](https://developer.ibm.com/exchanges/data/all/pubtabnet/) |
58
+ | FinTabNet| A dataset for Financial Report Tables with corresponding ground truth location and structure. 112k+ tables included.| [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/) |
59
+ | TableBank| TableBank is a new image-based table detection and recognition dataset built with novel weak supervision from Word and Latex documents on the internet, contains 417K high-quality labeled tables. | [TableBank](https://github.com/doc-analysis/TableBank) |
60
+
61
+ ## Models
62
+
63
+ ### TableModel04:
64
+ ![TableModel04](docs/tbm04.png)
65
+ **TableModel04rs (OTSL)** is our SOTA method that uses transformers to predict table structure and bounding boxes.
66
+
67
+
68
+ ## Configuration file
69
+
70
+ Example configuration can be seen inside test `tests/test_tf_predictor.py`
71
+ These are the main sections of the configuration file:
72
+
73
+ - `dataset`: The directory for prepared data and the parameters used during the data loading.
74
+ - `model`: The type, name and hyperparameters of the model. Also the directory to save/load the
75
+ trained checkpoint files.
76
+ - `train`: Parameters for the training of the model.
77
+ - `predict`: Parameters for the evaluation of the model.
78
+ - `dataset_wordmap`: Very important part that contains token maps.
79
+
80
+
81
+ ## Model weights
82
+
83
+ You can download the model weights and config files from the links:
84
+
85
+ - [TableFormer Checkpoint](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/tableformer)
86
+ - [beehive_v0.0.5](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/layout/beehive_v0.0.5)
87
+
88
+ Place the downloaded files into `tests/test_data/model_artifacts/` directory.
89
+
90
+
91
+ ## Inference Tests
92
+
93
+ This contains unit tests for Docling models.
94
+
95
+ First download the model weights (see above), then run:
96
+ ```
97
+ ./devtools/check_code.sh
98
+ ```
99
+
100
+ This will also generate prediction and matching visualizations that can be found here:
101
+ `tests/test_data/viz/`
102
+
103
+ Visualization outlines:
104
+ - `Light Pink`: border of recognized table
105
+ - `Grey`: OCR cells
106
+ - `Green`: prediction bboxes
107
+ - `Red`: OCR cells matched with prediction
108
+ - `Blue`: Post processed, match
109
+ - `Bold Blue`: column header
110
+ - `Bold Magenta`: row header
111
+ - `Bold Brown`: section row (if the table has one)
112
+
113
+
114
+ ## Demo
115
+
116
+ A demo application applies the `LayoutPredictor` to a directory `<input_dir>` that contains
117
+ `png` images and visualize the predictions inside another directory `<viz_dir>`.
118
+
119
+ First download the model weights (see above), then run:
120
+ ```
121
+ python -m demo.demo_layout_predictor -i <input_dir> -v <viz_dir>
122
+ ```
123
+
124
+ e.g.
125
+ ```
126
+ python -m demo.demo_layout_predictor -i tests/test_data/samples -v viz/
127
+ ```
@@ -0,0 +1,171 @@
1
+ #
2
+ # Copyright IBM Corp. 2024 - 2024
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
import os
from collections.abc import Iterable
from typing import Optional, Union

import numpy as np
import onnxruntime as ort
from PIL import Image
12
+
13
# File name of the ONNX model expected inside the artifact directory.
MODEL_CHECKPOINT_FN = "model.pt"
# Intra-op thread count used when neither the caller nor OMP_NUM_THREADS sets one.
DEFAULT_NUM_THREADS = 4


# Classes:
# Label id (the position in the list below) -> class name, used to translate
# the integer labels returned by the layout model.
CLASSES_MAP = dict(
    enumerate(
        [
            "background",  # 0
            "Caption",  # 1
            "Footnote",  # 2
            "Formula",  # 3
            "List-item",  # 4
            "Page-footer",  # 5
            "Page-header",  # 6
            "Picture",  # 7
            "Section-header",  # 8
            "Table",  # 9
            "Text",  # 10
            "Title",  # 11
            "Document Index",  # 12
            "Code",  # 13
            "Checkbox-Selected",  # 14
            "Checkbox-Unselected",  # 15
            "Form",  # 16
            "Key-Value Region",  # 17
        ]
    )
)
38
+
39
+
40
class LayoutPredictor:
    r"""
    Document layout prediction using ONNX
    """

    def __init__(
        self,
        artifact_path: str,
        num_threads: Optional[int] = None,
        use_cpu_only: bool = False,
    ):
        r"""
        Provide the artifact path that contains the LayoutModel ONNX file

        The number of threads is decided, in the following order, by:
        1. The init method parameter `num_threads`, if it is set.
        2. The envvar "OMP_NUM_THREADS", if it is set.
        3. The default value DEFAULT_NUM_THREADS.

        The execution providers are decided, in the following order:
        1. If the init method parameter `use_cpu_only` is True or the envvar
           "USE_CPU_ONLY" is set (to any value, including "0" or ""), use
           ["CPUExecutionProvider"].
        2. Otherwise, if "CUDAExecutionProvider" is available in onnxruntime, use
           ["CUDAExecutionProvider", "CPUExecutionProvider"].
        3. Otherwise use ["CPUExecutionProvider"].

        Parameters
        ----------
        artifact_path: Directory that contains the model ONNX file (MODEL_CHECKPOINT_FN).
        num_threads: (Optional) Number of intra-op threads used for inference.
        use_cpu_only: (Optional) If True, it forces CPU as the execution provider.

        Raises
        ------
        FileNotFoundError when the model's ONNX file is missing
        ValueError when "OMP_NUM_THREADS" is used and is not a valid integer
        """
        # Set basic params
        self._threshold = 0.6  # Score threshold: only detections above it are yielded
        self._image_size = 640  # Side length of the square image fed to the model
        # Passed to the model as "orig_target_sizes"; predict() rescales the
        # returned boxes from this frame back to the original image size.
        self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)

        # Get env vars
        self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
        if num_threads is None:
            num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
        self._num_threads = num_threads

        # Decide the execution providers
        if (
            not self._use_cpu_only
            and "CUDAExecutionProvider" in ort.get_available_providers()
        ):
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            providers = ["CPUExecutionProvider"]
        self._providers = providers

        # Model ONNX file
        self._onnx_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
        if not os.path.isfile(self._onnx_fn):
            raise FileNotFoundError("Missing ONNX file: {}".format(self._onnx_fn))

        # ONNX options
        self._options = ort.SessionOptions()
        self._options.intra_op_num_threads = self._num_threads
        self.sess = ort.InferenceSession(
            self._onnx_fn,
            sess_options=self._options,
            providers=self._providers,
        )

    def info(self) -> dict:
        r"""
        Get information about the configuration of LayoutPredictor
        """
        info = {
            "onnx_file": self._onnx_fn,
            "intra_op_num_threads": self._num_threads,
            "use_cpu_only": self._use_cpu_only,
            "providers": self._providers,
            "image_size": self._image_size,
            "threshold": self._threshold,
        }
        return info

    def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
        r"""
        Predict bounding boxes for a given image.
        The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
        [left, top, right, bottom]

        Parameter
        ---------
        orig_img: Image to be predicted as a PIL Image object or numpy array.

        Yield
        -----
        Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b".
        Coordinates are expressed in pixels of the original (un-resized) image.

        Raises
        ------
        TypeError when the input image is not supported. Because this method is a
        generator, the error is raised when iteration starts, not at call time.
        """
        # Convert image format
        if isinstance(orig_img, Image.Image):
            page_img = orig_img.convert("RGB")
        elif isinstance(orig_img, np.ndarray):
            page_img = Image.fromarray(orig_img).convert("RGB")
        else:
            raise TypeError("Not supported input image format")

        w, h = page_img.size
        page_img = page_img.resize((self._image_size, self._image_size))
        # Scale pixels to [0, 1], move channels first (HWC -> CHW) and add a batch axis
        page_data = np.array(page_img, dtype=np.uint8) / np.float32(255.0)
        page_data = np.expand_dims(np.transpose(page_data, axes=[2, 0, 1]), axis=0)

        # Predict
        labels, boxes, scores = self.sess.run(
            output_names=None,
            input_feed={
                "images": page_data,
                "orig_target_sizes": self._size,
            },
        )

        # Yield output: rescale boxes from the model input frame to the original image
        for label, box, score in zip(labels[0], boxes[0], scores[0]):
            if score > self._threshold:
                yield {
                    "l": box[0] / self._image_size * w,
                    "t": box[1] / self._image_size * h,
                    "r": box[2] / self._image_size * w,
                    "b": box[3] / self._image_size * h,
                    "label": CLASSES_MAP[label],
                    "confidence": score,
                }
@@ -0,0 +1,200 @@
1
+ #
2
+ # Copyright IBM Corp. 2024 - 2024
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ import argparse
6
+ import json
7
+ import logging
8
+ import os
9
+
10
+ import torch
11
+
12
+ import docling_ibm_models.tableformer.settings as s
13
+ from docling_ibm_models.tableformer.models.common.base_model import BaseModel
14
+
15
# Log level applied to this module's logger.
LOG_LEVEL = logging.DEBUG
# Logger named "common", created through the settings helper get_custom_logger.
logger = s.get_custom_logger("common", LOG_LEVEL)
17
+
18
+
19
def validate_config(config):
    r"""
    Validate the provided configuration file.

    The checks only run when both the "model" and the "preparation" sections
    are present; otherwise the configuration is accepted as-is.

    Parameters
    ----------
    config : dictionary
        Configuration for the tablemodel

    Returns
    -------
    bool : True on success

    Raises
    ------
    ValueError
        If 'preparation.max_tag_len' is missing, or if 'model.seq_len' is present
        but is not positive or exceeds 'preparation.max_tag_len' + 2.
    """
    if "model" not in config or "preparation" not in config:
        return True

    preparation = config["preparation"]
    # Explicit raises (not asserts) so validation still runs under `python -O`
    if "max_tag_len" not in preparation:
        raise ValueError("Config error: 'preparation.max_tag_len' parameter is missing")

    model = config["model"]
    if "seq_len" in model:
        seq_len = model["seq_len"]
        if not seq_len > 0:
            raise ValueError("Config error: 'model.seq_len' should be positive")
        if not seq_len <= preparation["max_tag_len"] + 2:
            raise ValueError(
                "Config error: 'model.seq_len' should be up to "
                "'preparation.max_tag_len' + 2"
            )

    return True
49
+
50
+
51
def parse_arguments():
    r"""
    Read the command line and load the configuration file it points to.

    The JSON file passed via ``-c/--config`` must exist; it is then loaded and
    validated through ``read_config``.

    Returns
    -------
    dict : The parsed configuration.
    """
    cli = argparse.ArgumentParser(description="Train the TableModel")
    cli.add_argument(
        "-c",
        "--config",
        required=True,
        default=None,
        help="configuration file (JSON)",
    )
    config_path = cli.parse_args().config

    assert os.path.isfile(config_path), "FAILURE: Config file not found."
    return read_config(config_path)
65
+
66
+
67
def read_config(config_filename):
    r"""
    Load a JSON configuration file and validate it.

    Parameters
    ----------
    config_filename : string
        Path to the JSON configuration file.

    Returns
    -------
    dict : The parsed configuration.

    Raises
    ------
    Whatever `validate_config` raises for an invalid configuration, plus the
    usual I/O and JSON decoding errors.
    """
    # JSON is UTF-8 by specification; do not depend on the locale's default encoding
    with open(config_filename, "r", encoding="utf-8") as fd:
        config = json.load(fd)

    # Validate the config file
    validate_config(config)

    return config
75
+
76
+
77
def safe_get_parameter(input_dict, index_path, default=None, required=False):
    r"""
    Safe get parameter from a nested dictionary.

    Provide a nested dictionary (dictionary of dictionaries) and a list of indices:
    - If the whole index path exists the value pointed by it is returned
    - Otherwise the default value is returned.

    Input:
        input_dict: Data structure of nested dictionaries.
        index_path: List with the indices path to follow inside the input_dict.
            An empty path returns input_dict itself.
        default: Default value to return if the indices path is broken.
        required: If true a ValueError exception will be raised in case the parameter does not exist
    Output:
        The value pointed by the index path or "default".
    Notes:
        If input_dict or index_path is None, "default" is returned even when
        required is True. A step whose current value is not a mapping (e.g. a
        number or string) is treated as a missing parameter.
    """
    from collections.abc import Mapping

    if input_dict is None or index_path is None:
        return default

    node = input_dict
    for key in index_path:
        # Only mappings can contain the next key; anything else breaks the path
        if not isinstance(node, Mapping) or key not in node:
            if required:
                raise ValueError("Missing parameter: {}".format(key))
            return default
        node = node[key]
    return node
111
+
112
+
113
def get_prepared_data_filename(prepared_data_part, dataset_name):
    r"""
    Resolve the file name of one part of the prepared data.

    The name template is looked up in ``s.PREPARED_DATA_PARTS``; if it contains
    the ``<POSTFIX>`` placeholder, the placeholder is replaced by the dataset name.

    Parameters
    ----------
    prepared_data_part : string
        Key of the prepared data part
    dataset_name : string
        Name of the dataset

    Returns
    -------
    string
        The full filename for the prepared file
    """
    name_template = s.PREPARED_DATA_PARTS[prepared_data_part]
    return (
        name_template.replace("<POSTFIX>", dataset_name)
        if "<POSTFIX>" in name_template
        else name_template
    )
133
+
134
+
135
def create_dataset_and_model(config, purpose, fixed_padding=False):
    r"""
    Create the dataset(s) and the model described by the configuration.

    Parameters
    ---------
    config : Dictionary
        The configuration of the model
    purpose : string
        One of "train", "eval", "predict"
    fixed_padding : bool
        Parameter passed to the constructor of TFDataset

    Returns
    -------
    If `purpose` equals s.PREDICT_PURPOSE:
        device, dataset, model
    Otherwise:
        device, dataset, dataset_val, model

    If no model class matches config["model"]["type"], a tuple of the same
    length filled with None is returned.

    device : selected device ("cuda:0" or "cpu")
    dataset : Instance of TFDataset for `purpose`
    dataset_val : Instance of TFDataset for validation, or None unless
        purpose == "train" and config["train"]["validation"] is truthy
    model : Instance of the model
    """
    from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset

    model_type = config["model"]["type"]
    model = None

    # A non-empty USE_CPU_ONLY forces the CPU; otherwise use the first GPU if any
    use_cpu_only = os.environ.get("USE_CPU_ONLY", False)
    if not use_cpu_only and torch.cuda.device_count() > 0:
        device = "cuda:0"
    else:
        device = "cpu"

    # Create the datasets
    dataset = TFDataset(config, purpose, fixed_padding=fixed_padding)
    dataset.set_device(device)
    dataset_val = None
    # Check the purpose first so configs without a "train" section still work
    # for eval/predict
    if purpose == "train" and config["train"]["validation"]:
        dataset_val = TFDataset(config, "val", fixed_padding=fixed_padding)
        dataset_val.set_device(device)

    if model_type == "TableModel04_rs":
        # Import for its side effect: the class must be defined before
        # BaseModel.__subclasses__() can list it (assumes it subclasses BaseModel directly)
        from docling_ibm_models.tableformer.models.table04_rs.tablemodel04_rs import (  # noqa: F401
            TableModel04_rs,
        )

    # Find the model class and create an instance of it
    for candidate in BaseModel.__subclasses__():
        if candidate.__name__ == model_type:
            init_data = dataset.get_init_data()
            model = candidate(config, init_data, purpose, device)

    is_predict = purpose == s.PREDICT_PURPOSE
    if model is None:
        logger.warning("Not found model: " + str(model_type))
        # Match the arity of the success path so callers can always unpack
        return (None, None, None) if is_predict else (None, None, None, None)

    logger.info("Found model: " + str(model_type))

    if is_predict:
        return device, dataset, model
    return device, dataset, dataset_val, model