docling-ibm-models 0.1.0 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling_ibm_models-0.1.0/LICENSE +21 -0
- docling_ibm_models-0.1.0/PKG-INFO +172 -0
- docling_ibm_models-0.1.0/README.md +127 -0
- docling_ibm_models-0.1.0/docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/__init__.py +0 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/common.py +200 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/__init__.py +0 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/functional.py +574 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/data_management/transforms.py +396 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/__init__.py +0 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/common/__init__.py +0 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/common/base_model.py +279 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/otsl.py +541 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/settings.py +90 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/test_prepare_image.py +99 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/utils/__init__.py +0 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/utils/utils.py +376 -0
- docling_ibm_models-0.1.0/docling_ibm_models/tableformer/utils/variance.py +175 -0
- docling_ibm_models-0.1.0/pyproject.toml +82 -0
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) [year] [fullname]
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
@@ -0,0 +1,172 @@
+Metadata-Version: 2.1
+Name: docling-ibm-models
+Version: 0.1.0
+Summary: This package contains the AI models used by the Docling PDF conversion package
+License: MIT
+Keywords: docling,convert,document,pdf,layout model,segmentation,table structure,table former
+Author: Nikos Livathinos
+Author-email: nli@zurich.ibm.com
+Requires-Python: >=3.11,<4.0
+Classifier: Development Status :: 5 - Production/Stable
+Classifier: Intended Audience :: Developers
+Classifier: Intended Audience :: Science/Research
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Operating System :: MacOS :: MacOS X
+Classifier: Operating System :: POSIX :: Linux
+Classifier: Programming Language :: Python :: 3
+Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
+Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+Requires-Dist: Distance (>=0.1.3,<0.2.0)
+Requires-Dist: Pillow (>=10.0.0,<11.0.0)
+Requires-Dist: apted (>=1.0.3,<2.0.0)
+Requires-Dist: jsonlines (>=3.1.0,<4.0.0)
+Requires-Dist: lxml (>=4.9.1,<5.0.0)
+Requires-Dist: mean_average_precision (>=2021.4.26.0,<2022.0.0.0)
+Requires-Dist: numpy (>=1.24.4,<2.0.0)
+Requires-Dist: onnxruntime (>=1.16.2,<2.0.0)
+Requires-Dist: opencv-python (>=4.9.0.80,<5.0.0.0) ; sys_platform != "linux"
+Requires-Dist: opencv-python-headless (>=4.9.0.80,<5.0.0.0) ; sys_platform == "linux"
+Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" and platform_machine == "x86_64" and sys_platform == "linux"
+Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2%2Bcpu-cp312-cp312-linux_x86_64.whl ; python_version == "3.12" and platform_machine == "x86_64" and sys_platform == "linux"
+Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2-cp311-none-macosx_10_9_x86_64.whl ; python_version == "3.11" and platform_machine == "x86_64" and sys_platform == "darwin"
+Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2-cp311-none-macosx_11_0_arm64.whl ; python_version == "3.11" and platform_machine == "arm64" and sys_platform == "darwin"
+Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2-cp312-none-macosx_10_9_x86_64.whl ; python_version == "3.12" and platform_machine == "x86_64" and sys_platform == "darwin"
+Requires-Dist: torch @ https://download.pytorch.org/whl/cpu/torch-2.2.2-cp312-none-macosx_11_0_arm64.whl ; python_version == "3.12" and platform_machine == "arm64" and sys_platform == "darwin"
+Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" and platform_machine == "x86_64" and sys_platform == "linux"
+Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2%2Bcpu-cp312-cp312-linux_x86_64.whl ; python_version == "3.12" and platform_machine == "x86_64" and sys_platform == "linux"
+Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2-cp311-cp311-macosx_10_13_x86_64.whl ; python_version == "3.11" and platform_machine == "x86_64" and sys_platform == "darwin"
+Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2-cp311-cp311-macosx_11_0_arm64.whl ; python_version == "3.11" and platform_machine == "arm64" and sys_platform == "darwin"
+Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2-cp312-cp312-macosx_10_13_x86_64.whl ; python_version == "3.12" and platform_machine == "x86_64" and sys_platform == "darwin"
+Requires-Dist: torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.17.2-cp312-cp312-macosx_11_0_arm64.whl ; python_version == "3.12" and platform_machine == "arm64" and sys_platform == "darwin"
+Requires-Dist: tqdm (>=4.64.0,<5.0.0)
+Description-Content-Type: text/markdown
+
+# Docling-models
+
+AI modules to support the Docling PDF document conversion project.
+
+- TableFormer is an AI module that recognizes the structure of a table and the bounding boxes of the table content.
+- The layout model is an AI model that provides, among other things, the ability to detect tables on a page. This package contains the inference code for the layout model.
+
+
+## Installation Instructions
+
+### MacOS / Linux
+
+To install `poetry` locally, use either `pip` or `homebrew`.
+
+To install `poetry` in a docker container, add the following to the Dockerfile:
+```
+ENV POETRY_NO_INTERACTION=1 \
+    POETRY_VIRTUALENVS_CREATE=false
+
+# Install poetry
+RUN curl -sSL 'https://install.python-poetry.org' > install-poetry.py \
+    && python install-poetry.py \
+    && poetry --version \
+    && rm install-poetry.py
+```
+
+To install and run the package, simply set up a poetry environment
+
+```
+poetry env use $(which python3.11)
+poetry shell
+```
+
+and install all the dependencies,
+
+```
+poetry install # this will only install the deps from the poetry.lock
+
+poetry install --no-dev # this will skip installing dev dependencies
+```
+
+To update or add new dependencies from `pyproject.toml`, rebuild `poetry.lock`:
+```
+poetry update
+```
+
+
+## Pipeline Overview
+![Pipeline Overview](docs/tablemodel_overview_color.png)
+
+## Datasets
+Below we list the datasets used, with their description, source, and ***"TableFormer Format"***. The TableFormer Format is our processed version of the original format: it works with the dataloader out of the box and augments the dataset where necessary to add missing ground truth (bounding boxes for empty cells).
+
+
+| Name | Description | URL |
+| ------------- |:-------------:|----|
+| PubTabNet | PubTabNet contains heterogeneous tables in both image and HTML format, 516k+ tables from the PubMed Central Open Access Subset. | [PubTabNet](https://developer.ibm.com/exchanges/data/all/pubtabnet/) |
+| FinTabNet | A dataset of financial report tables with corresponding ground-truth location and structure. 112k+ tables included. | [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/) |
+| TableBank | TableBank is an image-based table detection and recognition dataset built with novel weak supervision from Word and LaTeX documents on the internet; it contains 417k high-quality labeled tables. | [TableBank](https://github.com/doc-analysis/TableBank) |
+
+## Models
+
+### TableModel04:
+![TableModel04](docs/tbm04.png)
+**TableModel04rs (OTSL)** is our SOTA method that uses transformers to predict the table structure and bounding boxes.
+
+
+## Configuration file
+
+An example configuration can be seen inside the test `tests/test_tf_predictor.py`.
+These are the main sections of the configuration file:
+
+- `dataset`: The directory for the prepared data and the parameters used during data loading.
+- `model`: The type, name and hyperparameters of the model, as well as the directory to save/load the trained checkpoint files.
+- `train`: Parameters for training the model.
+- `predict`: Parameters for evaluating the model.
+- `dataset_wordmap`: A very important section that contains the token maps.
+
+
+## Model weights
+
+You can download the model weights and config files from the following links:
+
+- [TableFormer Checkpoint](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/tableformer)
+- [beehive_v0.0.5](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/layout/beehive_v0.0.5)
+
+Place the downloaded files into the `tests/test_data/model_artifacts/` directory.
+
+
+## Inference Tests
+
+This package contains unit tests for the Docling models.
+
+First download the model weights (see above), then run:
+```
+./devtools/check_code.sh
+```
+
+This will also generate prediction and matching visualizations, which can be found here:
+`tests/test_data/viz/`
+
+Visualization outlines:
+- `Light Pink`: border of the recognized table
+- `Grey`: OCR cells
+- `Green`: prediction bboxes
+- `Red`: OCR cells matched with a prediction
+- `Blue`: post-processed match
+- `Bold Blue`: column header
+- `Bold Magenta`: row header
+- `Bold Brown`: section row (if the table has one)
+
+
+## Demo
+
+A demo application allows you to apply the `LayoutPredictor` to a directory `<input_dir>` that contains
+`png` images and to visualize the predictions inside another directory `<viz_dir>`.
+
+First download the model weights (see above), then run:
+```
+python -m demo.demo_layout_predictor -i <input_dir> -v <viz_dir>
+```
+
+e.g.
+```
+python -m demo.demo_layout_predictor -i tests/test_data/samples -v viz/
+```
@@ -0,0 +1,127 @@
+# Docling-models
+
+AI modules to support the Docling PDF document conversion project.
+
+- TableFormer is an AI module that recognizes the structure of a table and the bounding boxes of the table content.
+- The layout model is an AI model that provides, among other things, the ability to detect tables on a page. This package contains the inference code for the layout model.
+
+
+## Installation Instructions
+
+### MacOS / Linux
+
+To install `poetry` locally, use either `pip` or `homebrew`.
+
+To install `poetry` in a docker container, add the following to the Dockerfile:
+```
+ENV POETRY_NO_INTERACTION=1 \
+    POETRY_VIRTUALENVS_CREATE=false
+
+# Install poetry
+RUN curl -sSL 'https://install.python-poetry.org' > install-poetry.py \
+    && python install-poetry.py \
+    && poetry --version \
+    && rm install-poetry.py
+```
+
+To install and run the package, simply set up a poetry environment
+
+```
+poetry env use $(which python3.11)
+poetry shell
+```
+
+and install all the dependencies,
+
+```
+poetry install # this will only install the deps from the poetry.lock
+
+poetry install --no-dev # this will skip installing dev dependencies
+```
+
+To update or add new dependencies from `pyproject.toml`, rebuild `poetry.lock`:
+```
+poetry update
+```
+
+
+## Pipeline Overview
+![Pipeline Overview](docs/tablemodel_overview_color.png)
+
+## Datasets
+Below we list the datasets used, with their description, source, and ***"TableFormer Format"***. The TableFormer Format is our processed version of the original format: it works with the dataloader out of the box and augments the dataset where necessary to add missing ground truth (bounding boxes for empty cells).
+
+
+| Name | Description | URL |
+| ------------- |:-------------:|----|
+| PubTabNet | PubTabNet contains heterogeneous tables in both image and HTML format, 516k+ tables from the PubMed Central Open Access Subset. | [PubTabNet](https://developer.ibm.com/exchanges/data/all/pubtabnet/) |
+| FinTabNet | A dataset of financial report tables with corresponding ground-truth location and structure. 112k+ tables included. | [FinTabNet](https://developer.ibm.com/exchanges/data/all/fintabnet/) |
+| TableBank | TableBank is an image-based table detection and recognition dataset built with novel weak supervision from Word and LaTeX documents on the internet; it contains 417k high-quality labeled tables. | [TableBank](https://github.com/doc-analysis/TableBank) |
+
+## Models
+
+### TableModel04:
+![TableModel04](docs/tbm04.png)
+**TableModel04rs (OTSL)** is our SOTA method that uses transformers to predict the table structure and bounding boxes.
+
+
+## Configuration file
+
+An example configuration can be seen inside the test `tests/test_tf_predictor.py`.
+These are the main sections of the configuration file:
+
+- `dataset`: The directory for the prepared data and the parameters used during data loading.
+- `model`: The type, name and hyperparameters of the model, as well as the directory to save/load the trained checkpoint files.
+- `train`: Parameters for training the model.
+- `predict`: Parameters for evaluating the model.
+- `dataset_wordmap`: A very important section that contains the token maps.
+
+
+## Model weights
+
+You can download the model weights and config files from the following links:
+
+- [TableFormer Checkpoint](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/tableformer)
+- [beehive_v0.0.5](https://huggingface.co/ds4sd/docling-models/tree/main/model_artifacts/layout/beehive_v0.0.5)
+
+Place the downloaded files into the `tests/test_data/model_artifacts/` directory.
+
+
+## Inference Tests
+
+This package contains unit tests for the Docling models.
+
+First download the model weights (see above), then run:
+```
+./devtools/check_code.sh
+```
+
+This will also generate prediction and matching visualizations, which can be found here:
+`tests/test_data/viz/`
+
+Visualization outlines:
+- `Light Pink`: border of the recognized table
+- `Grey`: OCR cells
+- `Green`: prediction bboxes
+- `Red`: OCR cells matched with a prediction
+- `Blue`: post-processed match
+- `Bold Blue`: column header
+- `Bold Magenta`: row header
+- `Bold Brown`: section row (if the table has one)
+
+
+## Demo
+
+A demo application allows you to apply the `LayoutPredictor` to a directory `<input_dir>` that contains
+`png` images and to visualize the predictions inside another directory `<viz_dir>`.
+
+First download the model weights (see above), then run:
+```
+python -m demo.demo_layout_predictor -i <input_dir> -v <viz_dir>
+```
+
+e.g.
+```
+python -m demo.demo_layout_predictor -i tests/test_data/samples -v viz/
+```
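For orientation, the sketch below shows the configuration shape the README describes as a Python dict (the file itself is JSON, loaded by `read_config` in `common.py` further down). Only the five top-level keys and the model type `TableModel04_rs` come from this package; every nested field name and value is an illustrative assumption, not the real schema, which lives in `tests/test_tf_predictor.py`:

```python
# Hypothetical sketch of the configuration layout. The five top-level keys are
# documented in the README; all nested fields below are assumptions.
example_config = {
    "dataset": {
        "prepared_data_dir": "./prepared_data",  # assumed name: prepared-data location
        "batch_size": 1,                         # assumed name: data-loading parameter
    },
    "model": {
        "type": "TableModel04_rs",      # model type name used in common.py
        "save_dir": "./checkpoints",    # assumed name: checkpoint directory
    },
    "train": {"epochs": 10},            # assumed training parameter
    "predict": {"bbox": True},          # assumed evaluation parameter
    "dataset_wordmap": {                # token maps (contents assumed)
        "word_map_tag": {"<start>": 0, "<end>": 1},
    },
}
```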
@@ -0,0 +1,171 @@
+#
+# Copyright IBM Corp. 2024 - 2024
+# SPDX-License-Identifier: MIT
+#
+import os
+from collections.abc import Iterable
+from typing import Optional, Union
+
+import numpy as np
+import onnxruntime as ort
+from PIL import Image
+
+MODEL_CHECKPOINT_FN = "model.pt"
+DEFAULT_NUM_THREADS = 4
+
+
+# Classes:
+CLASSES_MAP = {
+    0: "background",
+    1: "Caption",
+    2: "Footnote",
+    3: "Formula",
+    4: "List-item",
+    5: "Page-footer",
+    6: "Page-header",
+    7: "Picture",
+    8: "Section-header",
+    9: "Table",
+    10: "Text",
+    11: "Title",
+    12: "Document Index",
+    13: "Code",
+    14: "Checkbox-Selected",
+    15: "Checkbox-Unselected",
+    16: "Form",
+    17: "Key-Value Region",
+}
+
+
+class LayoutPredictor:
+    r"""
+    Document layout prediction using ONNX
+    """
+
+    def __init__(
+        self, artifact_path: str, num_threads: Optional[int] = None, use_cpu_only: bool = False
+    ):
+        r"""
+        Provide the artifact path that contains the LayoutModel ONNX file
+
+        The number of threads is decided, in the following order, by:
+        1. The init method parameter `num_threads`, if it is set.
+        2. The envvar "OMP_NUM_THREADS", if it is set.
+        3. The default value DEFAULT_NUM_THREADS.
+
+        The execution provider is decided in the following order:
+        1. If the init method parameter `use_cpu_only` is True or the envvar "USE_CPU_ONLY"
+           is set, use the "CPUExecutionProvider".
+        2. Otherwise, if the "CUDAExecutionProvider" is available, use:
+           ["CUDAExecutionProvider", "CPUExecutionProvider"]
+
+        Parameters
+        ----------
+        artifact_path: Path for the model ONNX file.
+        num_threads: (Optional) Number of threads to run the inference.
+        use_cpu_only: (Optional) If True, it forces CPU as the execution provider.
+
+        Raises
+        ------
+        FileNotFoundError when the model's ONNX file is missing
+        """
+        # Set basic params
+        self._threshold = 0.6  # Score threshold
+        self._image_size = 640
+        self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
+
+        # Get env vars
+        self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
+        if num_threads is None:
+            num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
+        self._num_threads = num_threads
+
+        # Decide the execution providers
+        if (
+            not self._use_cpu_only
+            and "CUDAExecutionProvider" in ort.get_available_providers()
+        ):
+            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        else:
+            providers = ["CPUExecutionProvider"]
+        self._providers = providers
+
+        # Model ONNX file
+        self._onnx_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
+        if not os.path.isfile(self._onnx_fn):
+            raise FileNotFoundError("Missing ONNX file: {}".format(self._onnx_fn))
+
+        # ONNX options
+        self._options = ort.SessionOptions()
+        self._options.intra_op_num_threads = self._num_threads
+        self.sess = ort.InferenceSession(
+            self._onnx_fn,
+            sess_options=self._options,
+            providers=self._providers,
+        )
+
+    def info(self) -> dict:
+        r"""
+        Get information about the configuration of LayoutPredictor
+        """
+        info = {
+            "onnx_file": self._onnx_fn,
+            "intra_op_num_threads": self._num_threads,
+            "use_cpu_only": self._use_cpu_only,
+            "providers": self._providers,
+            "image_size": self._image_size,
+            "threshold": self._threshold,
+        }
+        return info
+
+    def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
+        r"""
+        Predict bounding boxes for a given image.
+        The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
+        [left, top, right, bottom]
+
+        Parameters
+        ----------
+        orig_img: Image to be predicted as a PIL Image object or numpy array.
+
+        Yields
+        ------
+        Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b"
+
+        Raises
+        ------
+        TypeError when the input image is not supported
+        """
+        # Convert image format
+        if isinstance(orig_img, Image.Image):
+            page_img = orig_img.convert("RGB")
+        elif isinstance(orig_img, np.ndarray):
+            page_img = Image.fromarray(orig_img).convert("RGB")
+        else:
+            raise TypeError("Not supported input image format")
+
+        w, h = page_img.size
+        page_img = page_img.resize((self._image_size, self._image_size))
+        page_data = np.array(page_img, dtype=np.uint8) / np.float32(255.0)
+        page_data = np.expand_dims(np.transpose(page_data, axes=[2, 0, 1]), axis=0)
+
+        # Predict
+        labels, boxes, scores = self.sess.run(
+            output_names=None,
+            input_feed={
+                "images": page_data,
+                "orig_target_sizes": self._size,
+            },
+        )
+
+        # Yield output
+        for label, box, score in zip(labels[0], boxes[0], scores[0]):
+            if score > self._threshold:
+                # Rescale bbox coords from model space back to the original image size
+                yield {
+                    "l": box[0] / self._image_size * w,
+                    "t": box[1] / self._image_size * h,
+                    "r": box[2] / self._image_size * w,
+                    "b": box[3] / self._image_size * h,
+                    "label": CLASSES_MAP[label],
+                    "confidence": score,
+                }
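To make the `LayoutPredictor` interface above concrete, here is a minimal usage sketch. It assumes the weights were placed as described in the README, so that the artifact directory contains the `model.pt` ONNX file; the artifact and image paths below are hypothetical:

```python
# Minimal usage sketch for LayoutPredictor; the paths below are hypothetical.
from PIL import Image

from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor

# The artifact directory must contain "model.pt" (see MODEL_CHECKPOINT_FN above)
predictor = LayoutPredictor("tests/test_data/model_artifacts/layout/beehive_v0.0.5")
print(predictor.info())  # ONNX file, providers, thread count, threshold

img = Image.open("tests/test_data/samples/page.png")  # hypothetical sample image
for pred in predictor.predict(img):
    # Coordinates are already rescaled back to the original image size
    print("{label} ({confidence:.2f}): l={l:.1f} t={t:.1f} r={r:.1f} b={b:.1f}".format(**pred))
```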
File without changes
@@ -0,0 +1,200 @@
+#
+# Copyright IBM Corp. 2024 - 2024
+# SPDX-License-Identifier: MIT
+#
+import argparse
+import json
+import logging
+import os
+
+import torch
+
+import docling_ibm_models.tableformer.settings as s
+from docling_ibm_models.tableformer.models.common.base_model import BaseModel
+
+LOG_LEVEL = logging.DEBUG
+logger = s.get_custom_logger("common", LOG_LEVEL)
+
+
+def validate_config(config):
+    r"""
+    Validate the provided configuration file.
+    A ValueError exception will be thrown in case the config file is invalid
+
+    Parameters
+    ----------
+    config : dictionary
+        Configuration for the tablemodel
+
+    Returns
+    -------
+    bool : True on success
+    """
+    if "model" not in config:
+        return True
+    if "preparation" not in config:
+        return True
+    assert (
+        "max_tag_len" in config["preparation"]
+    ), "Config error: 'preparation.max_tag_len' parameter is missing"
+    if "seq_len" in config["model"]:
+        assert (
+            config["model"]["seq_len"] > 0
+        ), "Config error: 'model.seq_len' should be positive"
+        assert config["model"]["seq_len"] <= (
+            config["preparation"]["max_tag_len"] + 2
+        ), "Config error: 'model.seq_len' should be up to 'preparation.max_tag_len' + 2"
+
+    return True
+
+
+def parse_arguments():
+    r"""
+    Parse the input arguments.
+    A ValueError exception will be thrown in case the config file is invalid
+    """
+    parser = argparse.ArgumentParser(description="Train the TableModel")
+    parser.add_argument(
+        "-c", "--config", required=True, default=None, help="configuration file (JSON)"
+    )
+    args = parser.parse_args()
+    config_filename = args.config
+
+    assert os.path.isfile(config_filename), "FAILURE: Config file not found."
+    return read_config(config_filename)
+
+
+def read_config(config_filename):
+    with open(config_filename, "r") as fd:
+        config = json.load(fd)
+
+    # Validate the config file
+    validate_config(config)
+
+    return config
+
+
+def safe_get_parameter(input_dict, index_path, default=None, required=False):
+    r"""
+    Safely get a parameter from a nested dictionary.
+
+    Provide a nested dictionary (dictionary of dictionaries) and a list of indices:
+    - If the whole index path exists, the value it points to is returned.
+    - Otherwise the default value is returned.
+
+    Input:
+        input_dict: Data structure of nested dictionaries.
+        index_path: List with the indices path to follow inside the input_dict.
+        default: Default value to return if the indices path is broken.
+        required: If true a ValueError exception will be raised in case the parameter does not exist
+    Output:
+        The value pointed to by the index path, or "default".
+    """
+    if input_dict is None or index_path is None:
+        return default
+
+    d = input_dict
+    for i in index_path[:-1]:
+        if i not in d:
+            if required:
+                raise ValueError("Missing parameter: {}".format(i))
+            return default
+        d = d[i]
+
+    last_index = index_path[-1]
+    if last_index not in d:
+        if required:
+            raise ValueError("Missing parameter: {}".format(last_index))
+        return default
+
+    return d[last_index]
+
+
+def get_prepared_data_filename(prepared_data_part, dataset_name):
+    r"""
+    Build the full filename of the prepared data part
+
+    Parameters
+    ----------
+    prepared_data_part : string
+        Part of the prepared data
+    dataset_name : string
+        Name of the dataset
+
+    Returns
+    -------
+    string
+        The full filename for the prepared file
+    """
+    template = s.PREPARED_DATA_PARTS[prepared_data_part]
+    if "<POSTFIX>" in template:
+        template = template.replace("<POSTFIX>", dataset_name)
+    return template
+
+
+def create_dataset_and_model(config, purpose, fixed_padding=False):
+    r"""
+    Create a dataset and a model from the configuration.
+
+    Parameters
+    ----------
+    config : dictionary
+        The configuration of the model
+    purpose : string
+        One of "train", "eval", "predict"
+    fixed_padding : bool
+        Parameter passed to the constructor of the DataLoader
+
+    Returns
+    -------
+    In case a Model cannot be initialized, return None, None, None. Otherwise:
+
+    device : selected device
+    dataset : Instance of the DataLoader
+    model : Instance of the model
+
+    For purposes other than prediction, a validation dataset is also returned:
+    (device, dataset, dataset_val, model)
+    """
+    from docling_ibm_models.tableformer.data_management.tf_dataset import TFDataset
+
+    model_type = config["model"]["type"]
+    model = None
+
+    # Get env vars:
+    use_cpu_only = os.environ.get("USE_CPU_ONLY", False)
+    use_cuda_only = not use_cpu_only
+
+    # Select the device to run on
+    device = "cpu"  # Default, run on CPU
+    num_gpus = torch.cuda.device_count()  # Check if a GPU is available
+    if use_cuda_only:
+        device = "cuda:0" if num_gpus > 0 else "cpu"  # Run on the first available GPU
+    else:
+        device = "cpu"
+
+    # Create the DataLoader
+    # loader = DataLoader(config, purpose, fixed_padding=fixed_padding)
+    dataset = TFDataset(config, purpose, fixed_padding=fixed_padding)
+    dataset.set_device(device)
+    dataset_val = None
+    if config["train"]["validation"] and purpose == "train":
+        dataset_val = TFDataset(config, "val", fixed_padding=fixed_padding)
+        dataset_val.set_device(device)
+    if model_type == "TableModel04_rs":
+        from docling_ibm_models.tableformer.models.table04_rs.tablemodel04_rs import (  # noqa: F401
+            TableModel04_rs,
+        )
+    # Find the model class and create an instance of it
+    for candidate in BaseModel.__subclasses__():
+        if candidate.__name__ == model_type:
+            init_data = dataset.get_init_data()
+            model = candidate(config, init_data, purpose, device)
+
+    if model is None:
+        logger.warning("Model not found: " + str(model_type))
+        return None, None, None
+
+    logger.info("Found model: " + str(model_type))
+
+    if purpose == s.PREDICT_PURPOSE:
+        return device, dataset, model
+    else:
+        return device, dataset, dataset_val, model