docling-ibm-models 2.0.7__tar.gz → 3.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29) hide show
  1. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/PKG-INFO +4 -1
  2. docling_ibm_models-3.0.0/docling_ibm_models/layoutmodel/layout_predictor.py +175 -0
  3. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/data_management/matching_post_processor.py +4 -4
  4. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +1 -1
  5. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/data_management/tf_predictor.py +37 -40
  6. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/otsl.py +40 -30
  7. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/pyproject.toml +16 -1
  8. docling_ibm_models-2.0.7/docling_ibm_models/layoutmodel/layout_predictor.py +0 -167
  9. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/LICENSE +0 -0
  10. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/README.md +0 -0
  11. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/__init__.py +0 -0
  12. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/common.py +0 -0
  13. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/data_management/__init__.py +0 -0
  14. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/data_management/functional.py +0 -0
  15. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/data_management/transforms.py +0 -0
  16. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/models/__init__.py +0 -0
  17. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/models/common/__init__.py +0 -0
  18. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/models/common/base_model.py +0 -0
  19. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
  20. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +0 -0
  21. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +0 -0
  22. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +0 -0
  23. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +0 -0
  24. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/settings.py +0 -0
  25. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/utils/__init__.py +0 -0
  26. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/utils/app_profiler.py +0 -0
  27. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/utils/mem_monitor.py +0 -0
  28. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/utils/torch_utils.py +0 -0
  29. {docling_ibm_models-2.0.7 → docling_ibm_models-3.0.0}/docling_ibm_models/tableformer/utils/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: docling-ibm-models
3
- Version: 2.0.7
3
+ Version: 3.0.0
4
4
  Summary: This package contains the AI models used by the Docling PDF conversion package
5
5
  License: MIT
6
6
  Keywords: docling,convert,document,pdf,layout model,segmentation,table structure,table former
@@ -18,15 +18,18 @@ Classifier: Programming Language :: Python :: 3.9
18
18
  Classifier: Programming Language :: Python :: 3.10
19
19
  Classifier: Programming Language :: Python :: 3.11
20
20
  Classifier: Programming Language :: Python :: 3.12
21
+ Classifier: Programming Language :: Python :: 3.13
21
22
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
23
  Requires-Dist: Pillow (>=10.0.0,<11.0.0)
23
24
  Requires-Dist: huggingface_hub (>=0.23,<1)
24
25
  Requires-Dist: jsonlines (>=3.1.0,<4.0.0)
25
26
  Requires-Dist: numpy (>=1.24.4,<3.0.0)
26
27
  Requires-Dist: opencv-python-headless (>=4.6.0.66,<5.0.0.0)
28
+ Requires-Dist: safetensors[torch] (>=0.4.3,<1)
27
29
  Requires-Dist: torch (>=2.2.2,<3.0.0)
28
30
  Requires-Dist: torchvision (>=0,<1)
29
31
  Requires-Dist: tqdm (>=4.64.0,<5.0.0)
32
+ Requires-Dist: transformers (>=4.42.0,<5.0.0)
30
33
  Description-Content-Type: text/markdown
31
34
 
32
35
  [![PyPI version](https://img.shields.io/pypi/v/docling-ibm-models)](https://pypi.org/project/docling-ibm-models/)
@@ -0,0 +1,175 @@
1
+ #
2
+ # Copyright IBM Corp. 2024 - 2024
3
+ # SPDX-License-Identifier: MIT
4
+ #
5
+ import logging
6
+ import os
7
+ from collections.abc import Iterable
8
+ from typing import Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchvision.transforms as T
13
+ from PIL import Image
14
+ from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
15
+
16
+ _log = logging.getLogger(__name__)
17
+
18
+
19
+ class LayoutPredictor:
20
+ """
21
+ Document layout prediction using safe tensors
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ artifact_path: str,
27
+ device: str = "cpu",
28
+ num_threads: int = 4,
29
+ ):
30
+ """
31
+ Provide the artifact path that contains the LayoutModel file
32
+
33
+ Parameters
34
+ ----------
35
+ artifact_path: Path for the model torch file.
36
+ device: (Optional) device to run the inference.
37
+ num_threads: (Optional) Number of threads to run the inference if device = 'cpu'
38
+
39
+ Raises
40
+ ------
41
+ FileNotFoundError when the model's torch file is missing
42
+ """
43
+ # Initialize classes map:
44
+ self._classes_map = {
45
+ 0: "background",
46
+ 1: "Caption",
47
+ 2: "Footnote",
48
+ 3: "Formula",
49
+ 4: "List-item",
50
+ 5: "Page-footer",
51
+ 6: "Page-header",
52
+ 7: "Picture",
53
+ 8: "Section-header",
54
+ 9: "Table",
55
+ 10: "Text",
56
+ 11: "Title",
57
+ 12: "Document Index",
58
+ 13: "Code",
59
+ 14: "Checkbox-Selected",
60
+ 15: "Checkbox-Unselected",
61
+ 16: "Form",
62
+ 17: "Key-Value Region",
63
+ }
64
+
65
+ # Blacklisted classes
66
+ self._black_classes = set() # ["Form", "Key-Value Region"])
67
+
68
+ # Set basic params
69
+ self._threshold = 0.3 # Score threshold
70
+ self._image_size = 640
71
+ self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
72
+
73
+ # Set number of threads for CPU
74
+ self._device = torch.device(device)
75
+ self._num_threads = num_threads
76
+ if device == "cpu":
77
+ torch.set_num_threads(self._num_threads)
78
+
79
+ # Model file and configurations
80
+ self._st_fn = os.path.join(artifact_path, "model.safetensors")
81
+ if not os.path.isfile(self._st_fn):
82
+ raise FileNotFoundError("Missing safe tensors file: {}".format(self._st_fn))
83
+
84
+ # Load model and move to device
85
+ processor_config = os.path.join(artifact_path, "preprocessor_config.json")
86
+ model_config = os.path.join(artifact_path, "config.json")
87
+ self._image_processor = RTDetrImageProcessor.from_json_file(processor_config)
88
+ self._model = RTDetrForObjectDetection.from_pretrained(
89
+ artifact_path, config=model_config
90
+ ).to(self._device)
91
+ self._model.eval()
92
+
93
+ _log.debug("LayoutPredictor settings: {}".format(self.info()))
94
+
95
+ def info(self) -> dict:
96
+ """
97
+ Get information about the configuration of LayoutPredictor
98
+ """
99
+ info = {
100
+ "safe_tensors_file": self._st_fn,
101
+ "device": self._device.type,
102
+ "num_threads": self._num_threads,
103
+ "image_size": self._image_size,
104
+ "threshold": self._threshold,
105
+ }
106
+ return info
107
+
108
+ @torch.inference_mode()
109
+ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
110
+ """
111
+ Predict bounding boxes for a given image.
112
+ The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
113
+ [left, top, right, bottom]
114
+
115
+ Parameter
116
+ ---------
117
+ origin_img: Image to be predicted as a PIL Image object or numpy array.
118
+
119
+ Yield
120
+ -----
121
+ Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b"
122
+
123
+ Raises
124
+ ------
125
+ TypeError when the input image is not supported
126
+ """
127
+ # Convert image format
128
+ if isinstance(orig_img, Image.Image):
129
+ page_img = orig_img.convert("RGB")
130
+ elif isinstance(orig_img, np.ndarray):
131
+ page_img = Image.fromarray(orig_img).convert("RGB")
132
+ else:
133
+ raise TypeError("Not supported input image format")
134
+
135
+ resize = {"height": self._image_size, "width": self._image_size}
136
+ inputs = self._image_processor(
137
+ images=page_img,
138
+ return_tensors="pt",
139
+ size=resize,
140
+ ).to(self._device)
141
+ outputs = self._model(**inputs)
142
+ results = self._image_processor.post_process_object_detection(
143
+ outputs,
144
+ target_sizes=torch.tensor([page_img.size[::-1]]),
145
+ threshold=self._threshold,
146
+ )
147
+
148
+ w, h = page_img.size
149
+
150
+ result = results[0]
151
+ for score, label_id, box in zip(
152
+ result["scores"], result["labels"], result["boxes"]
153
+ ):
154
+ score = float(score.item())
155
+
156
+ label_id = int(label_id.item()) + 1 # Advance the label_id
157
+ label_str = self._classes_map[label_id]
158
+
159
+ # Filter out blacklisted classes
160
+ if label_str in self._black_classes:
161
+ continue
162
+
163
+ bbox_float = [float(b.item()) for b in box]
164
+ l = min(w, max(0, bbox_float[0]))
165
+ t = min(h, max(0, bbox_float[1]))
166
+ r = min(w, max(0, bbox_float[2]))
167
+ b = min(h, max(0, bbox_float[3]))
168
+ yield {
169
+ "l": l,
170
+ "t": t,
171
+ "r": r,
172
+ "b": b,
173
+ "label": label_str,
174
+ "confidence": score,
175
+ }
@@ -96,10 +96,10 @@ class MatchingPostProcessor:
96
96
  if cell["cell_class"] <= 1:
97
97
  allow_class = False
98
98
  else:
99
- print("***")
100
- print("no cell_class in...")
101
- print(cell)
102
- print("***")
99
+ self._log().debug("***")
100
+ self._log().debug("no cell_class in...")
101
+ self._log().debug(cell)
102
+ self._log().debug("***")
103
103
  if allow_class:
104
104
  match_list = matches[pdf_cell_id]
105
105
  for match in match_list:
@@ -264,7 +264,7 @@ class CellMatcher:
264
264
  r, o = otsl.html_to_otsl(table_html_structure, None, False, False, True, False)
265
265
  if not r:
266
266
  ermsg = "ERR#: COULD NOT CONVERT TO RS THIS TABLE TO COMPUTE SPANS"
267
- print(ermsg)
267
+ self._log().debug(ermsg)
268
268
  else:
269
269
  otsl_spans = o["otsl_spans"]
270
270
 
@@ -2,14 +2,17 @@
2
2
  # Copyright IBM Corp. 2024 - 2024
3
3
  # SPDX-License-Identifier: MIT
4
4
  #
5
+ import glob
5
6
  import json
6
7
  import logging
7
8
  import os
8
9
  from itertools import groupby
10
+ from pathlib import Path
9
11
 
10
12
  import cv2
11
13
  import numpy as np
12
14
  import torch
15
+ from safetensors.torch import load_model
13
16
 
14
17
  import docling_ibm_models.tableformer.common as c
15
18
  import docling_ibm_models.tableformer.data_management.transforms as T
@@ -30,6 +33,8 @@ from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler
30
33
  # LOG_LEVEL = logging.DEBUG
31
34
  LOG_LEVEL = logging.WARN
32
35
 
36
+ logger = s.get_custom_logger(__name__, LOG_LEVEL)
37
+
33
38
 
34
39
  class bcolors:
35
40
  HEADER = "\033[95m"
@@ -53,17 +58,17 @@ def otsl_sqr_chk(rs_list, logdebug):
53
58
 
54
59
  totcelnum = rs_list.count("fcel") + rs_list.count("ecel")
55
60
  if logdebug:
56
- print("Total number of cells = {}".format(totcelnum))
61
+ logger.debug("Total number of cells = {}".format(totcelnum))
57
62
 
58
63
  for ind, ln in enumerate(rs_list_split):
59
64
  ln.append("nl")
60
65
  if logdebug:
61
- print("{}".format(ln))
66
+ logger.debug("{}".format(ln))
62
67
  if len(ln) != init_tag_len:
63
68
  isSquare = False
64
69
  if isSquare:
65
70
  if logdebug:
66
- print(
71
+ logger.debug(
67
72
  "{}*OK* Table is square! *OK*{}".format(
68
73
  bcolors.OKGREEN, bcolors.ENDC
69
74
  )
@@ -71,8 +76,8 @@ def otsl_sqr_chk(rs_list, logdebug):
71
76
  else:
72
77
  if logdebug:
73
78
  err_name = "{}***** ERR ******{}"
74
- print(err_name.format(bcolors.FAIL, bcolors.ENDC))
75
- print(
79
+ logger.debug(err_name.format(bcolors.FAIL, bcolors.ENDC))
80
+ logger.debug(
76
81
  "{}*ERR* Table is not square! *ERR*{}".format(
77
82
  bcolors.FAIL, bcolors.ENDC
78
83
  )
@@ -80,45 +85,27 @@ def otsl_sqr_chk(rs_list, logdebug):
80
85
  return isSquare
81
86
 
82
87
 
83
- def decide_device(config: dict) -> str:
84
- r"""
85
- Decide the inference device based on the "predict.device_mode" parameter
86
- """
87
- device_mode = config["predict"].get("device_mode", "cpu")
88
- num_gpus = torch.cuda.device_count()
89
-
90
- if device_mode == "auto":
91
- device = "cuda:0" if num_gpus > 0 else "cpu"
92
- elif device_mode in ["gpu", "cuda"]:
93
- device = "cuda:0"
94
- else:
95
- device = "cpu"
96
- return device
97
-
98
-
99
88
  class TFPredictor:
100
89
  r"""
101
90
  Table predictions for the in-memory Docling API
102
91
  """
103
92
 
104
- def __init__(self, config, num_threads: int = None):
93
+ def __init__(self, config, device: str = "cpu", num_threads: int = 4):
105
94
  r"""
106
- The number of threads is decided, in the following order, by:
107
- 1. The init method parameter `num_threads`, if it is set.
108
- 2. The envvar "OMP_NUM_THREADS", if it is set.
109
- 3. The default value 4.
110
-
111
95
  Parameters
112
96
  ----------
113
- config : dict
114
- Parameters configuration
97
+ config : dict Parameters configuration
98
+ device: (Optional) torch device to run the inference.
99
+ num_threads: (Optional) Number of threads to run the inference if device = 'cpu'
100
+
115
101
  Raises
116
102
  ------
117
103
  ValueError
118
104
  When the model cannot be found
119
105
  """
120
- self._device = decide_device(config)
121
- self._log().info("Running on device: {}".format(self._device))
106
+ # self._device = torch.device(device)
107
+ self._device = device
108
+ self._log().info("Running on device: {}".format(device))
122
109
 
123
110
  self._config = config
124
111
  self.enable_post_process = True
@@ -131,11 +118,10 @@ class TFPredictor:
131
118
 
132
119
  self._init_word_map()
133
120
 
134
- # Set the number of torch threads
135
- if num_threads is None:
136
- num_threads = int(os.environ.get("OMP_NUM_THREADS", 4))
137
- self._num_threads = num_threads
138
- torch.set_num_threads(num_threads)
121
+ # Set the number of threads
122
+ if device == "cpu":
123
+ self._num_threads = num_threads
124
+ torch.set_num_threads(self._num_threads)
139
125
 
140
126
  # Load the model
141
127
  self._model = self._load_model()
@@ -200,10 +186,21 @@ class TFPredictor:
200
186
  if self._model_type == "TableModel02":
201
187
  self._remove_padding = True
202
188
 
203
- # Load model from checkpoint
204
- success, _, _, _, _ = model.load()
205
- if not success:
206
- err_msg = "Cannot load the model"
189
+ # Load model from safetensors
190
+ save_dir = self._config["model"]["save_dir"]
191
+ models_fn = glob.glob(f"{save_dir}/tableformer_*.safetensors")
192
+ if not models_fn:
193
+ err_msg = "Not able to find a model file for {}".format(self._model_type)
194
+ self._log().error(err_msg)
195
+ raise ValueError(err_msg)
196
+ model_fn = models_fn[
197
+ 0
198
+ ] # Take the first tableformer safetensors file inside the save_dir
199
+ missing, unexpected = load_model(model, model_fn, device=self._device)
200
+ if missing or unexpected:
201
+ err_msg = "Not able to load the model weights for {}".format(
202
+ self._model_type
203
+ )
207
204
  self._log().error(err_msg)
208
205
  raise ValueError(err_msg)
209
206
 
@@ -49,15 +49,15 @@ def otsl_sqr_chk(rs_list, name, logdebug):
49
49
  isSquare = False
50
50
  if isSquare:
51
51
  if logdebug:
52
- print(
52
+ logger.debug(
53
53
  "{}*OK* Table is square! *OK*{}".format(
54
54
  bcolors.OKGREEN, bcolors.ENDC
55
55
  )
56
56
  )
57
57
  else:
58
58
  err_name = "{}*ERR* " + name + " *ERR*{}"
59
- print(err_name.format(bcolors.FAIL, bcolors.ENDC))
60
- print(
59
+ logger.debug(err_name.format(bcolors.FAIL, bcolors.ENDC))
60
+ logger.debug(
61
61
  "{}*ERR* Table is not square! *ERR*{}".format(
62
62
  bcolors.FAIL, bcolors.ENDC
63
63
  )
@@ -89,9 +89,9 @@ def otsl_tags_cells_sync_chk(rs_list, cells, name, logdebug):
89
89
  countCellTags += 1
90
90
  if countCellTags != len(cells):
91
91
  err_name = "{}*!ERR* " + name + " *ERR!*{}"
92
- print(err_name.format(bcolors.FAIL, bcolors.ENDC))
92
+ logger.debug(err_name.format(bcolors.FAIL, bcolors.ENDC))
93
93
  err_msg = "{}*!ERR* Tags are not in sync with cells! *ERR!*{}"
94
- print(err_msg.format(bcolors.FAIL, bcolors.ENDC))
94
+ logger.debug(err_msg.format(bcolors.FAIL, bcolors.ENDC))
95
95
  isGood = False
96
96
  return isGood
97
97
 
@@ -131,11 +131,13 @@ def otsl_to_html(rs_list, logdebug):
131
131
  return rs_list
132
132
  html_table = []
133
133
  if logdebug:
134
- print("{}*Reconstructing HTML...*{}".format(bcolors.WARNING, bcolors.ENDC))
134
+ logger.debug(
135
+ "{}*Reconstructing HTML...*{}".format(bcolors.WARNING, bcolors.ENDC)
136
+ )
135
137
 
136
138
  if not otsl_sqr_chk(rs_list, "---", logdebug):
137
139
  # PAD TABLE TO SQUARE
138
- print("{}*Padding to square...*{}".format(bcolors.WARNING, bcolors.ENDC))
140
+ logger.debug("{}*Padding to square...*{}".format(bcolors.WARNING, bcolors.ENDC))
139
141
  rs_list = otsl_pad_to_sqr(rs_list, "lcel")
140
142
 
141
143
  # 2D structure, line by line:
@@ -144,7 +146,7 @@ def otsl_to_html(rs_list, logdebug):
144
146
  ]
145
147
 
146
148
  if logdebug:
147
- print("")
149
+ logger.debug("")
148
150
 
149
151
  # Sequentially store indexes of 2D spans that were registered to avoid re-registering them
150
152
  registry_2d_span = []
@@ -182,9 +184,9 @@ def otsl_to_html(rs_list, logdebug):
182
184
  span = True
183
185
  # Check if it has vertical span:
184
186
  if rs_row_ind + 1 < len(rs_list_split):
185
- # print(">>>")
186
- # print(rs_list_split[rs_row_ind + 1])
187
- # print(">>> rs_cell_ind = {}".format(rs_cell_ind))
187
+ # logger.debug(">>>")
188
+ # logger.debug(rs_list_split[rs_row_ind + 1])
189
+ # logger.debug(">>> rs_cell_ind = {}".format(rs_cell_ind))
188
190
  if rs_list_split[rs_row_ind + 1][rs_cell_ind] == "ucel":
189
191
  ddist = otsl_check_down(rs_list_split, rs_cell_ind, rs_row_ind)
190
192
  span = True
@@ -198,12 +200,12 @@ def otsl_to_html(rs_list, logdebug):
198
200
  span = True
199
201
  # Check if this 2D span was already registered,
200
202
  # If not - register, if yes - cancel span
201
- # print("rs_cell_ind: {}, xrdist:{}".format(rs_cell_ind, xrdist))
202
- # print("rs_row_ind: {}, xddist:{}".format(rs_cell_ind, xrdist))
203
+ # logger.debug("rs_cell_ind: {}, xrdist:{}".format(rs_cell_ind, xrdist))
204
+ # logger.debug("rs_row_ind: {}, xddist:{}".format(rs_cell_ind, xrdist))
203
205
  for x in range(rs_cell_ind, xrdist + rs_cell_ind):
204
206
  for y in range(rs_row_ind, xddist + rs_row_ind):
205
207
  reg2dind = str(x) + "_" + str(y)
206
- # print(reg2dind)
208
+ # logger.debug(reg2dind)
207
209
  if reg2dind in registry_2d_span:
208
210
  # Cell of the span is already in, cancel current span
209
211
  span = False
@@ -232,9 +234,13 @@ def otsl_to_html(rs_list, logdebug):
232
234
  html_table.extend(html_list)
233
235
 
234
236
  if logdebug:
235
- print("*********************** registry_2d_span ***************************")
236
- print(registry_2d_span)
237
- print("********************************************************************")
237
+ logger.debug(
238
+ "*********************** registry_2d_span ***************************"
239
+ )
240
+ logger.debug(registry_2d_span)
241
+ logger.debug(
242
+ "********************************************************************"
243
+ )
238
244
 
239
245
  return html_table
240
246
 
@@ -316,20 +322,24 @@ def html_to_otsl(table, writer, logdebug, extra_debug, include_html, use_writer)
316
322
  current_line_expands = []
317
323
 
318
324
  if logdebug:
319
- print("")
320
- print("*** {}: {} ***".format(table["split"], table["filename"]))
325
+ logger.debug("")
326
+ logger.debug("*** {}: {} ***".format(table["split"], table["filename"]))
321
327
 
322
328
  colnum = 0
323
329
 
324
330
  if extra_debug:
325
- print("========================== Input HTML ============================")
326
- print(table_html_structure["tokens"])
327
- print("==================================================================")
331
+ logger.debug(
332
+ "========================== Input HTML ============================"
333
+ )
334
+ logger.debug(table_html_structure["tokens"])
335
+ logger.debug(
336
+ "=================================================================="
337
+ )
328
338
 
329
339
  if logdebug:
330
- print("********")
331
- print("* OTSL *")
332
- print("********")
340
+ logger.debug("********")
341
+ logger.debug("* OTSL *")
342
+ logger.debug("********")
333
343
 
334
344
  for i in range(len(table_html_structure["tokens"])):
335
345
  html_tag = table_html_structure["tokens"][i]
@@ -377,7 +387,7 @@ def html_to_otsl(table, writer, logdebug, extra_debug, include_html, use_writer)
377
387
  extra_columns = pre_line_len - cur_line_len - 1
378
388
  if extra_columns > 0:
379
389
  if extra_debug:
380
- print(
390
+ logger.debug(
381
391
  "Extra columns needed in row: {}".format(
382
392
  extra_columns
383
393
  )
@@ -534,11 +544,11 @@ def html_to_otsl(table, writer, logdebug, extra_debug, include_html, use_writer)
534
544
  writer.write(out_line)
535
545
 
536
546
  if logdebug:
537
- print("{}Reconstructed HTML:{}".format(bcolors.OKGREEN, bcolors.ENDC))
538
- print(rHTML)
547
+ logger.debug("{}Reconstructed HTML:{}".format(bcolors.OKGREEN, bcolors.ENDC))
548
+ logger.debug(rHTML)
539
549
  # original HTML
540
550
  oHTML = out_line["html"]["html_structure"]
541
- print("{}Original HTML:{}".format(bcolors.OKBLUE, bcolors.ENDC))
542
- print(oHTML)
551
+ logger.debug("{}Original HTML:{}".format(bcolors.OKBLUE, bcolors.ENDC))
552
+ logger.debug(oHTML)
543
553
 
544
554
  return True, out_line
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "docling-ibm-models"
3
- version = "2.0.7" # DO NOT EDIT, updated automatically
3
+ version = "3.0.0" # DO NOT EDIT, updated automatically
4
4
  description = "This package contains the AI models used by the Docling PDF conversion package"
5
5
  authors = ["Nikos Livathinos <nli@zurich.ibm.com>", "Maxim Lysak <mly@zurich.ibm.com>", "Ahmed Nassar <ahn@zurich.ibm.com>", "Christoph Auer <cau@zurich.ibm.com>", "Michele Dolfi <dol@zurich.ibm.com>", "Peter Staar <taa@zurich.ibm.com>"]
6
6
  license = "MIT"
@@ -24,12 +24,14 @@ packages = [
24
24
  python = "^3.9"
25
25
  torch = "^2.2.2"
26
26
  torchvision = "^0"
27
+ transformers = "^4.42.0"
27
28
  numpy = ">=1.24.4,<3.0.0"
28
29
  jsonlines = "^3.1.0"
29
30
  Pillow = "^10.0.0"
30
31
  tqdm = "^4.64.0"
31
32
  opencv-python-headless = "^4.6.0.66"
32
33
  huggingface_hub = ">=0.23,<1"
34
+ safetensors = {version=">=0.4.3,<1", extras=["torch"]}
33
35
 
34
36
  [tool.poetry.group.dev.dependencies]
35
37
  black = {extras = ["jupyter"], version = "^24.4.2"}
@@ -96,3 +98,16 @@ branch = "main"
96
98
  parser_angular_allowed_types = "build,chore,ci,docs,feat,fix,perf,style,refactor,test"
97
99
  parser_angular_minor_types = "feat"
98
100
  parser_angular_patch_types = "fix,perf"
101
+
102
+
103
+ # [tool.mypy]
104
+ # pretty = true
105
+ # no_implicit_optional = true
106
+ # python_version = "3.10"
107
+ #
108
+ # [[tool.mypy.overrides]]
109
+ # module = [
110
+ # "torchvision.*",
111
+ # "transformers.*"
112
+ # ]
113
+ # ignore_missing_imports = true
@@ -1,167 +0,0 @@
1
- #
2
- # Copyright IBM Corp. 2024 - 2024
3
- # SPDX-License-Identifier: MIT
4
- #
5
- import os
6
- from collections.abc import Iterable
7
- from typing import Union
8
-
9
- import numpy as np
10
- import torch
11
- import torchvision.transforms as T
12
- from PIL import Image
13
-
14
- MODEL_CHECKPOINT_FN = "model.pt"
15
- DEFAULT_NUM_THREADS = 4
16
-
17
-
18
- class LayoutPredictor:
19
- r"""
20
- Document layout prediction using torch
21
- """
22
-
23
- def __init__(
24
- self, artifact_path: str, num_threads: int = None, use_cpu_only: bool = False
25
- ):
26
- r"""
27
- Provide the artifact path that contains the LayoutModel file
28
-
29
- The number of threads is decided, in the following order, by:
30
- 1. The init method parameter `num_threads`, if it is set.
31
- 2. The envvar "OMP_NUM_THREADS", if it is set.
32
- 3. The default value DEFAULT_NUM_THREADS.
33
-
34
- The execution provided is decided, in the following order:
35
- 1. If the init method parameter `cpu_only` is True or the envvar "USE_CPU_ONLY" is set,
36
- it uses the "CPUExecutionProvider".
37
- 3. Otherwise if the "CUDAExecutionProvider" is present, use:
38
- ["CUDAExecutionProvider", "CPUExecutionProvider"]:
39
-
40
- Parameters
41
- ----------
42
- artifact_path: Path for the model torch file.
43
- num_threads: (Optional) Number of threads to run the inference.
44
- use_cpu_only: (Optional) If True, it forces CPU as the execution provider.
45
-
46
- Raises
47
- ------
48
- FileNotFoundError when the model's torch file is missing
49
- """
50
- # Initialize classes map:
51
- self._classes_map = {
52
- 0: "background",
53
- 1: "Caption",
54
- 2: "Footnote",
55
- 3: "Formula",
56
- 4: "List-item",
57
- 5: "Page-footer",
58
- 6: "Page-header",
59
- 7: "Picture",
60
- 8: "Section-header",
61
- 9: "Table",
62
- 10: "Text",
63
- 11: "Title",
64
- 12: "Document Index",
65
- 13: "Code",
66
- 14: "Checkbox-Selected",
67
- 15: "Checkbox-Unselected",
68
- 16: "Form",
69
- 17: "Key-Value Region",
70
- }
71
-
72
- # Blacklisted classes
73
- self._black_classes = set(["Form", "Key-Value Region"])
74
-
75
- # Set basic params
76
- self._threshold = 0.6 # Score threshold
77
- self._image_size = 640
78
- self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
79
- self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
80
-
81
- # Model file
82
- self._torch_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
83
- if not os.path.isfile(self._torch_fn):
84
- raise FileNotFoundError("Missing torch file: {}".format(self._torch_fn))
85
-
86
- # Get env vars
87
- if num_threads is None:
88
- num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
89
- self._num_threads = num_threads
90
-
91
- self.model = torch.jit.load(self._torch_fn)
92
-
93
- def info(self) -> dict:
94
- r"""
95
- Get information about the configuration of LayoutPredictor
96
- """
97
- info = {
98
- "torch_file": self._torch_fn,
99
- "use_cpu_only": self._use_cpu_only,
100
- "image_size": self._image_size,
101
- "threshold": self._threshold,
102
- }
103
- return info
104
-
105
- def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
106
- r"""
107
- Predict bounding boxes for a given image.
108
- The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
109
- [left, top, right, bottom]
110
-
111
- Parameter
112
- ---------
113
- origin_img: Image to be predicted as a PIL Image object or numpy array.
114
-
115
- Yield
116
- -----
117
- Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b"
118
-
119
- Raises
120
- ------
121
- TypeError when the input image is not supported
122
- """
123
- # Convert image format
124
- if isinstance(orig_img, Image.Image):
125
- page_img = orig_img.convert("RGB")
126
- elif isinstance(orig_img, np.ndarray):
127
- page_img = Image.fromarray(orig_img).convert("RGB")
128
- else:
129
- raise TypeError("Not supported input image format")
130
-
131
- w, h = page_img.size
132
- orig_size = torch.tensor([w, h])[None]
133
-
134
- transforms = T.Compose(
135
- [
136
- T.Resize((640, 640)),
137
- T.ToTensor(),
138
- ]
139
- )
140
- img = transforms(page_img)[None]
141
- # Predict
142
- with torch.no_grad():
143
- labels, boxes, scores = self.model(img, orig_size)
144
-
145
- # Yield output
146
- for label_idx, box, score in zip(labels[0], boxes[0], scores[0]):
147
- # Filter out blacklisted classes
148
- label_idx = int(label_idx.item())
149
- score = float(score.item())
150
- label = self._classes_map[label_idx + 1]
151
- if label in self._black_classes:
152
- continue
153
-
154
- # Check against threshold
155
- if score > self._threshold:
156
- l = min(w, max(0, box[0]))
157
- t = min(h, max(0, box[1]))
158
- r = min(w, max(0, box[2]))
159
- b = min(h, max(0, box[3]))
160
- yield {
161
- "l": l,
162
- "t": t,
163
- "r": r,
164
- "b": b,
165
- "label": label,
166
- "confidence": score,
167
- }