docling 2.17.0__py3-none-any.whl → 2.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling/backend/html_backend.py +18 -18
- docling/backend/md_backend.py +144 -75
- docling/backend/mspowerpoint_backend.py +39 -27
- docling/backend/msword_backend.py +173 -131
- docling/cli/main.py +8 -0
- docling/cli/models.py +105 -0
- docling/cli/tools.py +17 -0
- docling/datamodel/document.py +2 -0
- docling/datamodel/settings.py +18 -1
- docling/document_converter.py +12 -2
- docling/models/base_model.py +3 -0
- docling/models/code_formula_model.py +15 -9
- docling/models/document_picture_classifier.py +11 -8
- docling/models/easyocr_model.py +50 -3
- docling/models/layout_model.py +49 -3
- docling/models/table_structure_model.py +53 -7
- docling/pipeline/base_pipeline.py +4 -2
- docling/pipeline/standard_pdf_pipeline.py +25 -24
- docling/utils/glm_utils.py +4 -0
- docling/utils/model_downloader.py +72 -0
- docling/utils/utils.py +24 -0
- {docling-2.17.0.dist-info → docling-2.19.0.dist-info}/METADATA +11 -5
- {docling-2.17.0.dist-info → docling-2.19.0.dist-info}/RECORD +26 -23
- {docling-2.17.0.dist-info → docling-2.19.0.dist-info}/WHEEL +1 -1
- {docling-2.17.0.dist-info → docling-2.19.0.dist-info}/entry_points.txt +1 -0
- {docling-2.17.0.dist-info → docling-2.19.0.dist-info}/LICENSE +0 -0
docling/cli/models.py
ADDED
@@ -0,0 +1,105 @@
+import logging
+import warnings
+from enum import Enum
+from pathlib import Path
+from typing import Annotated, Optional
+
+import typer
+from rich.console import Console
+from rich.logging import RichHandler
+
+from docling.datamodel.settings import settings
+from docling.utils.model_downloader import download_models
+
+warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
+warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")
+
+console = Console()
+err_console = Console(stderr=True)
+
+
+app = typer.Typer(
+    name="Docling models helper",
+    no_args_is_help=True,
+    add_completion=False,
+    pretty_exceptions_enable=False,
+)
+
+
+class _AvailableModels(str, Enum):
+    LAYOUT = "layout"
+    TABLEFORMER = "tableformer"
+    CODE_FORMULA = "code_formula"
+    PICTURE_CLASSIFIER = "picture_classifier"
+    EASYOCR = "easyocr"
+
+
+@app.command("download")
+def download(
+    output_dir: Annotated[
+        Path,
+        typer.Option(
+            ...,
+            "-o",
+            "--output-dir",
+            help="The directory where all the models are downloaded.",
+        ),
+    ] = (settings.cache_dir / "models"),
+    force: Annotated[
+        bool, typer.Option(..., help="If true, the download will be forced")
+    ] = False,
+    models: Annotated[
+        Optional[list[_AvailableModels]],
+        typer.Argument(
+            help=f"Models to download (default behavior: all will be downloaded)",
+        ),
+    ] = None,
+    quiet: Annotated[
+        bool,
+        typer.Option(
+            ...,
+            "-q",
+            "--quiet",
+            help="No extra output is generated, the CLI prints only the directory with the cached models.",
+        ),
+    ] = False,
+):
+    if not quiet:
+        FORMAT = "%(message)s"
+        logging.basicConfig(
+            level=logging.INFO,
+            format="[blue]%(message)s[/blue]",
+            datefmt="[%X]",
+            handlers=[RichHandler(show_level=False, show_time=False, markup=True)],
+        )
+    to_download = models or [m for m in _AvailableModels]
+    output_dir = download_models(
+        output_dir=output_dir,
+        force=force,
+        progress=(not quiet),
+        with_layout=_AvailableModels.LAYOUT in to_download,
+        with_tableformer=_AvailableModels.TABLEFORMER in to_download,
+        with_code_formula=_AvailableModels.CODE_FORMULA in to_download,
+        with_picture_classifier=_AvailableModels.PICTURE_CLASSIFIER in to_download,
+        with_easyocr=_AvailableModels.EASYOCR in to_download,
+    )
+
+    if quiet:
+        typer.echo(output_dir)
+    else:
+        typer.secho(f"\nModels downloaded into: {output_dir}.", fg="green")
+
+        console.print(
+            "\n",
+            "Docling can now be configured for running offline using the local artifacts.\n\n",
+            "Using the CLI:",
+            f"`docling --artifacts-path={output_dir} FILE`",
+            "\n",
+            "Using Python: see the documentation at <https://ds4sd.github.io/docling/usage>.",
+        )
+
+
+click_app = typer.main.get_command(app)
+
+if __name__ == "__main__":
+    app()
docling/cli/tools.py
ADDED
@@ -0,0 +1,17 @@
+import typer
+
+from docling.cli.models import app as models_app
+
+app = typer.Typer(
+    name="Docling helpers",
+    no_args_is_help=True,
+    add_completion=False,
+    pretty_exceptions_enable=False,
+)
+
+app.add_typer(models_app, name="models")
+
+click_app = typer.main.get_command(app)
+
+if __name__ == "__main__":
+    app()
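The two new modules above add a `models download` helper to Docling's CLI; both can be run directly as modules (e.g. `python -m docling.cli.tools models download`) thanks to their `__main__` guards, and the new entry in `entry_points.txt` presumably exposes the same command as a console script. As a sketch of the programmatic equivalent, the `download_models` helper they wrap can be called directly; the keyword arguments below mirror the call site above, and the output path is only an example:

```python
# Minimal sketch (not part of the diff): prefetch Docling's models for offline use.
from pathlib import Path

from docling.utils.model_downloader import download_models

artifacts_dir = download_models(
    output_dir=Path("./docling-artifacts"),  # defaults to settings.cache_dir / "models"
    force=False,
    progress=True,
    with_layout=True,
    with_tableformer=True,
    with_code_formula=True,
    with_picture_classifier=True,
    with_easyocr=True,
)
print(f"Models cached in: {artifacts_dir}")
```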
docling/datamodel/document.py
CHANGED
@@ -157,6 +157,8 @@ class InputDocument(BaseModel):
                 self.page_count = self._backend.page_count()
                 if not self.page_count <= self.limits.max_num_pages:
                     self.valid = False
+                elif self.page_count < self.limits.page_range[0]:
+                    self.valid = False
 
         except (FileNotFoundError, OSError) as e:
             self.valid = False
docling/datamodel/settings.py
CHANGED
@@ -1,13 +1,28 @@
 import sys
 from pathlib import Path
+from typing import Annotated, Tuple
 
-from pydantic import BaseModel
+from pydantic import BaseModel, PlainValidator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
 
+def _validate_page_range(v: Tuple[int, int]) -> Tuple[int, int]:
+    if v[0] < 1 or v[1] < v[0]:
+        raise ValueError(
+            "Invalid page range: start must be ≥ 1 and end must be ≥ start."
+        )
+    return v
+
+
+PageRange = Annotated[Tuple[int, int], PlainValidator(_validate_page_range)]
+
+DEFAULT_PAGE_RANGE: PageRange = (1, sys.maxsize)
+
+
 class DocumentLimits(BaseModel):
     max_num_pages: int = sys.maxsize
     max_file_size: int = sys.maxsize
+    page_range: PageRange = DEFAULT_PAGE_RANGE
 
 
 class BatchConcurrencySettings(BaseModel):
@@ -46,5 +61,7 @@ class AppSettings(BaseSettings):
     perf: BatchConcurrencySettings
     debug: DebugSettings
 
+    cache_dir: Path = Path.home() / ".cache" / "docling"
+
 
 settings = AppSettings(perf=BatchConcurrencySettings(), debug=DebugSettings())
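A quick sketch (not part of the diff) of how the new `PageRange` type behaves when used through `DocumentLimits`: ranges are 1-based and inclusive, and invalid combinations are rejected by `_validate_page_range`:

```python
from pydantic import ValidationError

from docling.datamodel.settings import DocumentLimits

# Valid: pages 2 through 5 (1-based, inclusive).
limits = DocumentLimits(page_range=(2, 5))

# Invalid: start < 1 (or end < start) raises a pydantic ValidationError.
try:
    DocumentLimits(page_range=(0, 5))
except ValidationError as err:
    print(err)
```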
docling/document_converter.py
CHANGED
@@ -1,9 +1,10 @@
 import logging
+import math
 import sys
 import time
 from functools import partial
 from pathlib import Path
-from typing import Dict, Iterable, Iterator, List, Optional, Type, Union
+from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Type, Union
 
 from pydantic import BaseModel, ConfigDict, model_validator, validate_call
 
@@ -31,7 +32,12 @@ from docling.datamodel.document import (
     _DocumentConversionInput,
 )
 from docling.datamodel.pipeline_options import PipelineOptions
-from docling.datamodel.settings import
+from docling.datamodel.settings import (
+    DEFAULT_PAGE_RANGE,
+    DocumentLimits,
+    PageRange,
+    settings,
+)
 from docling.exceptions import ConversionError
 from docling.pipeline.base_pipeline import BasePipeline
 from docling.pipeline.simple_pipeline import SimplePipeline
@@ -184,6 +190,7 @@ class DocumentConverter:
         raises_on_error: bool = True,
         max_num_pages: int = sys.maxsize,
         max_file_size: int = sys.maxsize,
+        page_range: PageRange = DEFAULT_PAGE_RANGE,
     ) -> ConversionResult:
         all_res = self.convert_all(
             source=[source],
@@ -191,6 +198,7 @@
             max_num_pages=max_num_pages,
             max_file_size=max_file_size,
             headers=headers,
+            page_range=page_range,
         )
         return next(all_res)
 
@@ -202,10 +210,12 @@
         raises_on_error: bool = True,  # True: raises on first conversion error; False: does not raise on conv error
         max_num_pages: int = sys.maxsize,
         max_file_size: int = sys.maxsize,
+        page_range: PageRange = DEFAULT_PAGE_RANGE,
     ) -> Iterator[ConversionResult]:
         limits = DocumentLimits(
             max_num_pages=max_num_pages,
             max_file_size=max_file_size,
+            page_range=page_range,
         )
         conv_input = _DocumentConversionInput(
             path_or_stream_iterator=source, limits=limits, headers=headers
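A minimal usage sketch of the new `page_range` argument (the input file name is just an example); pages outside the range are skipped by the pipeline, and documents with fewer pages than the range start are marked invalid:

```python
from docling.document_converter import DocumentConverter

converter = DocumentConverter()

# Convert only pages 2-5 (1-based, inclusive) of the input document.
result = converter.convert("report.pdf", page_range=(2, 5))
print(result.document.export_to_markdown())
```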
docling/models/base_model.py
CHANGED
@@ -6,6 +6,7 @@ from typing_extensions import TypeVar
 
 from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
 from docling.datamodel.document import ConversionResult
+from docling.datamodel.settings import settings
 
 
 class BasePageModel(ABC):
@@ -21,6 +22,8 @@ EnrichElementT = TypeVar("EnrichElementT", default=NodeItem)
 
 class GenericEnrichmentModel(ABC, Generic[EnrichElementT]):
 
+    elements_batch_size: int = settings.perf.elements_batch_size
+
     @abstractmethod
     def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
         pass
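The new `elements_batch_size` class attribute lets each enrichment model pick its own batch size; the `base_pipeline.py` change further below chunks elements by this value instead of a single global setting. A hypothetical sketch of a custom model overriding it; the `BaseEnrichmentModel` base class and the `__call__` signature are assumed from the surrounding code rather than shown in this diff:

```python
from typing import Any, Iterable

from docling_core.types.doc import DoclingDocument, NodeItem, PictureItem

from docling.models.base_model import BaseEnrichmentModel


class TinyBatchPictureModel(BaseEnrichmentModel):
    # Override the new class attribute (default: settings.perf.elements_batch_size).
    elements_batch_size = 4

    def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
        return isinstance(element, PictureItem)

    def __call__(
        self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
    ) -> Iterable[Any]:
        # Each invocation now receives at most 4 elements.
        for element in element_batch:
            yield element
```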
docling/models/code_formula_model.py
CHANGED
@@ -2,6 +2,7 @@ import re
 from pathlib import Path
 from typing import Iterable, List, Literal, Optional, Tuple, Union
 
+import numpy as np
 from docling_core.types.doc import (
     CodeItem,
     DocItemLabel,
@@ -61,13 +62,15 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
         Processes the given batch of elements and enriches them with predictions.
     """
 
+    _model_repo_folder = "ds4sd--CodeFormula"
+    elements_batch_size = 5
     images_scale = 1.66  # = 120 dpi, aligned with training data resolution
     expansion_factor = 0.03
 
     def __init__(
         self,
         enabled: bool,
-        artifacts_path: Optional[
+        artifacts_path: Optional[Path],
         options: CodeFormulaModelOptions,
         accelerator_options: AcceleratorOptions,
     ):
@@ -96,29 +99,32 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
         )
 
         if artifacts_path is None:
-            artifacts_path = self.
+            artifacts_path = self.download_models()
         else:
-            artifacts_path =
+            artifacts_path = artifacts_path / self._model_repo_folder
 
         self.code_formula_model = CodeFormulaPredictor(
-            artifacts_path=artifacts_path,
+            artifacts_path=str(artifacts_path),
             device=device,
             num_threads=accelerator_options.num_threads,
         )
 
     @staticmethod
-    def
-        local_dir: Optional[Path] = None,
+    def download_models(
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
     ) -> Path:
         from huggingface_hub import snapshot_download
         from huggingface_hub.utils import disable_progress_bars
 
-
+        if not progress:
+            disable_progress_bars()
         download_path = snapshot_download(
             repo_id="ds4sd/CodeFormula",
             force_download=force,
             local_dir=local_dir,
-            revision="v1.0.
+            revision="v1.0.1",
         )
 
         return Path(download_path)
@@ -226,7 +232,7 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
             return
 
         labels: List[str] = []
-        images: List[Image.Image] = []
+        images: List[Union[Image.Image, np.ndarray]] = []
         elements: List[TextItem] = []
         for el in element_batch:
             assert isinstance(el.item, TextItem)
docling/models/document_picture_classifier.py
CHANGED
@@ -1,6 +1,7 @@
 from pathlib import Path
 from typing import Iterable, List, Literal, Optional, Tuple, Union
 
+import numpy as np
 from docling_core.types.doc import (
     DoclingDocument,
     NodeItem,
@@ -55,12 +56,13 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
         Processes a batch of elements and adds classification annotations.
     """
 
+    _model_repo_folder = "ds4sd--DocumentFigureClassifier"
    images_scale = 2
 
     def __init__(
         self,
         enabled: bool,
-        artifacts_path: Optional[
+        artifacts_path: Optional[Path],
         options: DocumentPictureClassifierOptions,
         accelerator_options: AcceleratorOptions,
     ):
@@ -88,24 +90,25 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
         )
 
         if artifacts_path is None:
-            artifacts_path = self.
+            artifacts_path = self.download_models()
         else:
-            artifacts_path =
+            artifacts_path = artifacts_path / self._model_repo_folder
 
         self.document_picture_classifier = DocumentFigureClassifierPredictor(
-            artifacts_path=artifacts_path,
+            artifacts_path=str(artifacts_path),
             device=device,
             num_threads=accelerator_options.num_threads,
         )
 
     @staticmethod
-    def
-        local_dir: Optional[Path] = None, force: bool = False
+    def download_models(
+        local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
     ) -> Path:
         from huggingface_hub import snapshot_download
         from huggingface_hub.utils import disable_progress_bars
 
-
+        if not progress:
+            disable_progress_bars()
         download_path = snapshot_download(
             repo_id="ds4sd/DocumentFigureClassifier",
             force_download=force,
@@ -159,7 +162,7 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
                 yield element
             return
 
-        images: List[Image.Image] = []
+        images: List[Union[Image.Image, np.ndarray]] = []
         elements: List[PictureItem] = []
         for el in element_batch:
             assert isinstance(el, PictureItem)
docling/models/easyocr_model.py
CHANGED
@@ -1,7 +1,10 @@
 import logging
 import warnings
-
+import zipfile
+from pathlib import Path
+from typing import Iterable, List, Optional
 
+import httpx
 import numpy
 import torch
 from docling_core.types.doc import BoundingBox, CoordOrigin
@@ -17,14 +20,18 @@ from docling.datamodel.settings import settings
 from docling.models.base_ocr_model import BaseOcrModel
 from docling.utils.accelerator_utils import decide_device
 from docling.utils.profiling import TimeRecorder
+from docling.utils.utils import download_url_with_progress
 
 _log = logging.getLogger(__name__)
 
 
 class EasyOcrModel(BaseOcrModel):
+    _model_repo_folder = "EasyOcr"
+
     def __init__(
         self,
         enabled: bool,
+        artifacts_path: Optional[Path],
         options: EasyOcrOptions,
         accelerator_options: AcceleratorOptions,
     ):
@@ -62,15 +69,55 @@
                 )
                 use_gpu = self.options.use_gpu
 
+            download_enabled = self.options.download_enabled
+            model_storage_directory = self.options.model_storage_directory
+            if artifacts_path is not None and model_storage_directory is None:
+                download_enabled = False
+                model_storage_directory = str(artifacts_path / self._model_repo_folder)
+
             self.reader = easyocr.Reader(
                 lang_list=self.options.lang,
                 gpu=use_gpu,
-                model_storage_directory=
+                model_storage_directory=model_storage_directory,
                 recog_network=self.options.recog_network,
-                download_enabled=
+                download_enabled=download_enabled,
                 verbose=False,
             )
 
+    @staticmethod
+    def download_models(
+        detection_models: List[str] = ["craft"],
+        recognition_models: List[str] = ["english_g2", "latin_g2"],
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+    ) -> Path:
+        # Models are located in https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/config.py
+        from easyocr.config import detection_models as det_models_dict
+        from easyocr.config import recognition_models as rec_models_dict
+
+        if local_dir is None:
+            local_dir = settings.cache_dir / "models" / EasyOcrModel._model_repo_folder
+
+        local_dir.mkdir(parents=True, exist_ok=True)
+
+        # Collect models to download
+        download_list = []
+        for model_name in detection_models:
+            if model_name in det_models_dict:
+                download_list.append(det_models_dict[model_name])
+        for model_name in recognition_models:
+            if model_name in rec_models_dict["gen2"]:
+                download_list.append(rec_models_dict["gen2"][model_name])
+
+        # Download models
+        for model_details in download_list:
+            buf = download_url_with_progress(model_details["url"], progress=progress)
+            with zipfile.ZipFile(buf, "r") as zip_ref:
+                zip_ref.extractall(local_dir)
+
+        return local_dir
+
     def __call__(
         self, conv_res: ConversionResult, page_batch: Iterable[Page]
     ) -> Iterable[Page]:
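A small sketch (not part of the diff) of prefetching the EasyOCR weights with the new static helper so that a later run can point `artifacts_path` at the parent folder with downloads disabled; the target directory is only an example:

```python
from pathlib import Path

from docling.models.easyocr_model import EasyOcrModel

# Downloads the default detection ("craft") and recognition
# ("english_g2", "latin_g2") models and unpacks them locally.
model_dir = EasyOcrModel.download_models(
    local_dir=Path("./docling-artifacts/EasyOcr"),
    progress=True,
)
print(model_dir)
```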
docling/models/layout_model.py
CHANGED
@@ -1,7 +1,8 @@
 import copy
 import logging
+import warnings
 from pathlib import Path
-from typing import Iterable
+from typing import Iterable, Optional, Union
 
 from docling_core.types.doc import DocItemLabel
 from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
@@ -21,6 +22,8 @@ _log = logging.getLogger(__name__)
 
 
 class LayoutModel(BasePageModel):
+    _model_repo_folder = "ds4sd--docling-models"
+    _model_path = "model_artifacts/layout"
 
     TEXT_ELEM_LABELS = [
         DocItemLabel.TEXT,
@@ -42,15 +45,56 @@ class LayoutModel(BasePageModel):
     FORMULA_LABEL = DocItemLabel.FORMULA
     CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
 
-    def __init__(
+    def __init__(
+        self, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions
+    ):
         device = decide_device(accelerator_options.device)
 
+        if artifacts_path is None:
+            artifacts_path = self.download_models() / self._model_path
+        else:
+            # will become the default in the future
+            if (artifacts_path / self._model_repo_folder).exists():
+                artifacts_path = (
+                    artifacts_path / self._model_repo_folder / self._model_path
+                )
+            elif (artifacts_path / self._model_path).exists():
+                warnings.warn(
+                    "The usage of artifacts_path containing directly "
+                    f"{self._model_path} is deprecated. Please point "
+                    "the artifacts_path to the parent containing "
+                    f"the {self._model_repo_folder} folder.",
+                    DeprecationWarning,
+                    stacklevel=3,
+                )
+                artifacts_path = artifacts_path / self._model_path
+
         self.layout_predictor = LayoutPredictor(
             artifact_path=str(artifacts_path),
             device=device,
             num_threads=accelerator_options.num_threads,
         )
 
+    @staticmethod
+    def download_models(
+        local_dir: Optional[Path] = None,
+        force: bool = False,
+        progress: bool = False,
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        if not progress:
+            disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id="ds4sd/docling-models",
+            force_download=force,
+            local_dir=local_dir,
+            revision="v2.1.0",
+        )
+
+        return Path(download_path)
+
     def draw_clusters_and_cells_side_by_side(
         self, conv_res, page, clusters, mode_prefix: str, show: bool = False
     ):
@@ -106,10 +150,12 @@
         else:
             with TimeRecorder(conv_res, "layout"):
                 assert page.size is not None
+                page_image = page.get_image(scale=1.0)
+                assert page_image is not None
 
                 clusters = []
                 for ix, pred_item in enumerate(
-                    self.layout_predictor.predict(
+                    self.layout_predictor.predict(page_image)
                 ):
                     label = DocItemLabel(
                         pred_item["label"]
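Given the repo-folder constants introduced here (`ds4sd--docling-models` with `model_artifacts/layout`, and `model_artifacts/tableformer` in the next file), an artifacts directory produced by the downloader is expected to look roughly like the illustrative layout below; pointing `artifacts_path` directly at a folder that contains `model_artifacts/...` still works but now emits a `DeprecationWarning`:

```text
<artifacts-path>/
├── ds4sd--docling-models/
│   └── model_artifacts/
│       ├── layout/
│       └── tableformer/
├── ds4sd--CodeFormula/
├── ds4sd--DocumentFigureClassifier/
└── EasyOcr/
```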
docling/models/table_structure_model.py
CHANGED
@@ -1,6 +1,7 @@
 import copy
+import warnings
 from pathlib import Path
-from typing import Iterable
+from typing import Iterable, Optional, Union
 
 import numpy
 from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
@@ -22,10 +23,13 @@ from docling.utils.profiling import TimeRecorder
 
 
 class TableStructureModel(BasePageModel):
+    _model_repo_folder = "ds4sd--docling-models"
+    _model_path = "model_artifacts/tableformer"
+
     def __init__(
         self,
         enabled: bool,
-        artifacts_path: Path,
+        artifacts_path: Optional[Path],
         options: TableStructureOptions,
         accelerator_options: AcceleratorOptions,
     ):
@@ -35,6 +39,26 @@ class TableStructureModel(BasePageModel):
 
         self.enabled = enabled
         if self.enabled:
+
+            if artifacts_path is None:
+                artifacts_path = self.download_models() / self._model_path
+            else:
+                # will become the default in the future
+                if (artifacts_path / self._model_repo_folder).exists():
+                    artifacts_path = (
+                        artifacts_path / self._model_repo_folder / self._model_path
+                    )
+                elif (artifacts_path / self._model_path).exists():
+                    warnings.warn(
+                        "The usage of artifacts_path containing directly "
+                        f"{self._model_path} is deprecated. Please point "
+                        "the artifacts_path to the parent containing "
+                        f"the {self._model_repo_folder} folder.",
+                        DeprecationWarning,
+                        stacklevel=3,
+                    )
+                    artifacts_path = artifacts_path / self._model_path
+
             if self.mode == TableFormerMode.ACCURATE:
                 artifacts_path = artifacts_path / "accurate"
             else:
@@ -58,6 +82,24 @@ class TableStructureModel(BasePageModel):
             )
             self.scale = 2.0  # Scale up table input images to 144 dpi
 
+    @staticmethod
+    def download_models(
+        local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
+    ) -> Path:
+        from huggingface_hub import snapshot_download
+        from huggingface_hub.utils import disable_progress_bars
+
+        if not progress:
+            disable_progress_bars()
+        download_path = snapshot_download(
+            repo_id="ds4sd/docling-models",
+            force_download=force,
+            local_dir=local_dir,
+            revision="v2.1.0",
+        )
+
+        return Path(download_path)
+
     def draw_table_and_cells(
         self,
         conv_res: ConversionResult,
@@ -209,12 +251,16 @@ class TableStructureModel(BasePageModel):
                             tc.bbox = tc.bbox.scaled(1 / self.scale)
                         table_cells.append(tc)
 
+                    assert "predict_details" in table_out
+
                     # Retrieving cols/rows, after post processing:
-                    num_rows = table_out["predict_details"]
-                    num_cols = table_out["predict_details"]
-                    otsl_seq =
-                    "
-
+                    num_rows = table_out["predict_details"].get("num_rows", 0)
+                    num_cols = table_out["predict_details"].get("num_cols", 0)
+                    otsl_seq = (
+                        table_out["predict_details"]
+                        .get("prediction", {})
+                        .get("rs_seq", [])
+                    )
 
                     tbl = Table(
                         otsl_seq=otsl_seq,
docling/pipeline/base_pipeline.py
CHANGED
@@ -79,7 +79,7 @@ class BasePipeline(ABC):
             for model in self.enrichment_pipe:
                 for element_batch in chunkify(
                     _prepare_elements(conv_res, model),
-
+                    model.elements_batch_size,
                 ):
                     for element in model(
                         doc=conv_res.document, element_batch=element_batch
@@ -141,7 +141,9 @@ class PaginatedPipeline(BasePipeline):  # TODO this is a bad name.
         with TimeRecorder(conv_res, "doc_build", scope=ProfilingScope.DOCUMENT):
 
             for i in range(0, conv_res.input.page_count):
-                conv_res.
+                start_page, end_page = conv_res.input.limits.page_range
+                if (start_page - 1) <= i <= (end_page - 1):
+                    conv_res.pages.append(Page(page_no=i))
 
             try:
                 # Iterate batches of pages (page_batch_size) in the doc
|