docling 2.18.0__py3-none-any.whl → 2.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling/backend/md_backend.py +62 -46
- docling/backend/msword_backend.py +1 -1
- docling/cli/main.py +13 -0
- docling/cli/models.py +107 -0
- docling/cli/tools.py +17 -0
- docling/datamodel/pipeline_options.py +52 -2
- docling/datamodel/settings.py +2 -0
- docling/models/base_model.py +5 -2
- docling/models/code_formula_model.py +15 -9
- docling/models/document_picture_classifier.py +11 -8
- docling/models/easyocr_model.py +49 -4
- docling/models/layout_model.py +49 -3
- docling/models/picture_description_api_model.py +101 -0
- docling/models/picture_description_base_model.py +64 -0
- docling/models/picture_description_vlm_model.py +109 -0
- docling/models/table_structure_model.py +44 -2
- docling/pipeline/base_pipeline.py +1 -1
- docling/pipeline/standard_pdf_pipeline.py +66 -25
- docling/utils/model_downloader.py +84 -0
- docling/utils/utils.py +24 -0
- {docling-2.18.0.dist-info → docling-2.20.0.dist-info}/METADATA +8 -4
- {docling-2.18.0.dist-info → docling-2.20.0.dist-info}/RECORD +25 -19
- {docling-2.18.0.dist-info → docling-2.20.0.dist-info}/entry_points.txt +1 -0
- {docling-2.18.0.dist-info → docling-2.20.0.dist-info}/LICENSE +0 -0
- {docling-2.18.0.dist-info → docling-2.20.0.dist-info}/WHEEL +0 -0
docling/models/easyocr_model.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1
1
|
import logging
|
2
2
|
import warnings
|
3
|
-
|
3
|
+
import zipfile
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Iterable, List, Optional
|
4
6
|
|
5
7
|
import numpy
|
6
|
-
import torch
|
7
8
|
from docling_core.types.doc import BoundingBox, CoordOrigin
|
8
9
|
|
9
10
|
from docling.datamodel.base_models import Cell, OcrCell, Page
|
@@ -17,14 +18,18 @@ from docling.datamodel.settings import settings
|
|
17
18
|
from docling.models.base_ocr_model import BaseOcrModel
|
18
19
|
from docling.utils.accelerator_utils import decide_device
|
19
20
|
from docling.utils.profiling import TimeRecorder
|
21
|
+
from docling.utils.utils import download_url_with_progress
|
20
22
|
|
21
23
|
_log = logging.getLogger(__name__)
|
22
24
|
|
23
25
|
|
24
26
|
class EasyOcrModel(BaseOcrModel):
|
27
|
+
_model_repo_folder = "EasyOcr"
|
28
|
+
|
25
29
|
def __init__(
|
26
30
|
self,
|
27
31
|
enabled: bool,
|
32
|
+
artifacts_path: Optional[Path],
|
28
33
|
options: EasyOcrOptions,
|
29
34
|
accelerator_options: AcceleratorOptions,
|
30
35
|
):
|
@@ -62,15 +67,55 @@ class EasyOcrModel(BaseOcrModel):
|
|
62
67
|
)
|
63
68
|
use_gpu = self.options.use_gpu
|
64
69
|
|
70
|
+
download_enabled = self.options.download_enabled
|
71
|
+
model_storage_directory = self.options.model_storage_directory
|
72
|
+
if artifacts_path is not None and model_storage_directory is None:
|
73
|
+
download_enabled = False
|
74
|
+
model_storage_directory = str(artifacts_path / self._model_repo_folder)
|
75
|
+
|
65
76
|
self.reader = easyocr.Reader(
|
66
77
|
lang_list=self.options.lang,
|
67
78
|
gpu=use_gpu,
|
68
|
-
model_storage_directory=self.options.model_storage_directory,
|
79
|
+
model_storage_directory=model_storage_directory,
|
69
80
|
recog_network=self.options.recog_network,
|
70
|
-
download_enabled=self.options.download_enabled,
|
81
|
+
download_enabled=download_enabled,
|
71
82
|
verbose=False,
|
72
83
|
)
|
73
84
|
|
85
|
+
@staticmethod
|
86
|
+
def download_models(
|
87
|
+
detection_models: List[str] = ["craft"],
|
88
|
+
recognition_models: List[str] = ["english_g2", "latin_g2"],
|
89
|
+
local_dir: Optional[Path] = None,
|
90
|
+
force: bool = False,
|
91
|
+
progress: bool = False,
|
92
|
+
) -> Path:
|
93
|
+
# Models are located in https://github.com/JaidedAI/EasyOCR/blob/master/easyocr/config.py
|
94
|
+
from easyocr.config import detection_models as det_models_dict
|
95
|
+
from easyocr.config import recognition_models as rec_models_dict
|
96
|
+
|
97
|
+
if local_dir is None:
|
98
|
+
local_dir = settings.cache_dir / "models" / EasyOcrModel._model_repo_folder
|
99
|
+
|
100
|
+
local_dir.mkdir(parents=True, exist_ok=True)
|
101
|
+
|
102
|
+
# Collect models to download
|
103
|
+
download_list = []
|
104
|
+
for model_name in detection_models:
|
105
|
+
if model_name in det_models_dict:
|
106
|
+
download_list.append(det_models_dict[model_name])
|
107
|
+
for model_name in recognition_models:
|
108
|
+
if model_name in rec_models_dict["gen2"]:
|
109
|
+
download_list.append(rec_models_dict["gen2"][model_name])
|
110
|
+
|
111
|
+
# Download models
|
112
|
+
for model_details in download_list:
|
113
|
+
buf = download_url_with_progress(model_details["url"], progress=progress)
|
114
|
+
with zipfile.ZipFile(buf, "r") as zip_ref:
|
115
|
+
zip_ref.extractall(local_dir)
|
116
|
+
|
117
|
+
return local_dir
|
118
|
+
|
74
119
|
def __call__(
|
75
120
|
self, conv_res: ConversionResult, page_batch: Iterable[Page]
|
76
121
|
) -> Iterable[Page]:
|
docling/models/layout_model.py
CHANGED
@@ -1,7 +1,8 @@
|
|
1
1
|
import copy
|
2
2
|
import logging
|
3
|
+
import warnings
|
3
4
|
from pathlib import Path
|
4
|
-
from typing import Iterable
|
5
|
+
from typing import Iterable, Optional, Union
|
5
6
|
|
6
7
|
from docling_core.types.doc import DocItemLabel
|
7
8
|
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
@@ -21,6 +22,8 @@ _log = logging.getLogger(__name__)
|
|
21
22
|
|
22
23
|
|
23
24
|
class LayoutModel(BasePageModel):
|
25
|
+
_model_repo_folder = "ds4sd--docling-models"
|
26
|
+
_model_path = "model_artifacts/layout"
|
24
27
|
|
25
28
|
TEXT_ELEM_LABELS = [
|
26
29
|
DocItemLabel.TEXT,
|
@@ -42,15 +45,56 @@ class LayoutModel(BasePageModel):
|
|
42
45
|
FORMULA_LABEL = DocItemLabel.FORMULA
|
43
46
|
CONTAINER_LABELS = [DocItemLabel.FORM, DocItemLabel.KEY_VALUE_REGION]
|
44
47
|
|
45
|
-
def __init__(self, artifacts_path: Path, accelerator_options: AcceleratorOptions):
|
48
|
+
def __init__(
|
49
|
+
self, artifacts_path: Optional[Path], accelerator_options: AcceleratorOptions
|
50
|
+
):
|
46
51
|
device = decide_device(accelerator_options.device)
|
47
52
|
|
53
|
+
if artifacts_path is None:
|
54
|
+
artifacts_path = self.download_models() / self._model_path
|
55
|
+
else:
|
56
|
+
# will become the default in the future
|
57
|
+
if (artifacts_path / self._model_repo_folder).exists():
|
58
|
+
artifacts_path = (
|
59
|
+
artifacts_path / self._model_repo_folder / self._model_path
|
60
|
+
)
|
61
|
+
elif (artifacts_path / self._model_path).exists():
|
62
|
+
warnings.warn(
|
63
|
+
"The usage of artifacts_path containing directly "
|
64
|
+
f"{self._model_path} is deprecated. Please point "
|
65
|
+
"the artifacts_path to the parent containing "
|
66
|
+
f"the {self._model_repo_folder} folder.",
|
67
|
+
DeprecationWarning,
|
68
|
+
stacklevel=3,
|
69
|
+
)
|
70
|
+
artifacts_path = artifacts_path / self._model_path
|
71
|
+
|
48
72
|
self.layout_predictor = LayoutPredictor(
|
49
73
|
artifact_path=str(artifacts_path),
|
50
74
|
device=device,
|
51
75
|
num_threads=accelerator_options.num_threads,
|
52
76
|
)
|
53
77
|
|
78
|
+
@staticmethod
|
79
|
+
def download_models(
|
80
|
+
local_dir: Optional[Path] = None,
|
81
|
+
force: bool = False,
|
82
|
+
progress: bool = False,
|
83
|
+
) -> Path:
|
84
|
+
from huggingface_hub import snapshot_download
|
85
|
+
from huggingface_hub.utils import disable_progress_bars
|
86
|
+
|
87
|
+
if not progress:
|
88
|
+
disable_progress_bars()
|
89
|
+
download_path = snapshot_download(
|
90
|
+
repo_id="ds4sd/docling-models",
|
91
|
+
force_download=force,
|
92
|
+
local_dir=local_dir,
|
93
|
+
revision="v2.1.0",
|
94
|
+
)
|
95
|
+
|
96
|
+
return Path(download_path)
|
97
|
+
|
54
98
|
def draw_clusters_and_cells_side_by_side(
|
55
99
|
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
|
56
100
|
):
|
@@ -106,10 +150,12 @@ class LayoutModel(BasePageModel):
|
|
106
150
|
else:
|
107
151
|
with TimeRecorder(conv_res, "layout"):
|
108
152
|
assert page.size is not None
|
153
|
+
page_image = page.get_image(scale=1.0)
|
154
|
+
assert page_image is not None
|
109
155
|
|
110
156
|
clusters = []
|
111
157
|
for ix, pred_item in enumerate(
|
112
|
-
self.layout_predictor.predict(page.get_image(scale=1.0))
|
158
|
+
self.layout_predictor.predict(page_image)
|
113
159
|
):
|
114
160
|
label = DocItemLabel(
|
115
161
|
pred_item["label"]
|
docling/models/picture_description_api_model.py
@@ -0,0 +1,101 @@
|
|
1
|
+
import base64
|
2
|
+
import io
|
3
|
+
import logging
|
4
|
+
from typing import Iterable, List, Optional
|
5
|
+
|
6
|
+
import requests
|
7
|
+
from PIL import Image
|
8
|
+
from pydantic import BaseModel, ConfigDict
|
9
|
+
|
10
|
+
from docling.datamodel.pipeline_options import PictureDescriptionApiOptions
|
11
|
+
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
|
12
|
+
|
13
|
+
_log = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class ChatMessage(BaseModel):
|
17
|
+
role: str
|
18
|
+
content: str
|
19
|
+
|
20
|
+
|
21
|
+
class ResponseChoice(BaseModel):
|
22
|
+
index: int
|
23
|
+
message: ChatMessage
|
24
|
+
finish_reason: str
|
25
|
+
|
26
|
+
|
27
|
+
class ResponseUsage(BaseModel):
|
28
|
+
prompt_tokens: int
|
29
|
+
completion_tokens: int
|
30
|
+
total_tokens: int
|
31
|
+
|
32
|
+
|
33
|
+
class ApiResponse(BaseModel):
|
34
|
+
model_config = ConfigDict(
|
35
|
+
protected_namespaces=(),
|
36
|
+
)
|
37
|
+
|
38
|
+
id: str
|
39
|
+
model: Optional[str] = None # returned by openai
|
40
|
+
choices: List[ResponseChoice]
|
41
|
+
created: int
|
42
|
+
usage: ResponseUsage
|
43
|
+
|
44
|
+
|
45
|
+
class PictureDescriptionApiModel(PictureDescriptionBaseModel):
|
46
|
+
# elements_batch_size = 4
|
47
|
+
|
48
|
+
def __init__(self, enabled: bool, options: PictureDescriptionApiOptions):
|
49
|
+
super().__init__(enabled=enabled, options=options)
|
50
|
+
self.options: PictureDescriptionApiOptions
|
51
|
+
|
52
|
+
if self.enabled:
|
53
|
+
if options.url.host != "localhost":
|
54
|
+
raise NotImplementedError(
|
55
|
+
"The options try to connect to remote APIs which are not yet allowed."
|
56
|
+
)
|
57
|
+
|
58
|
+
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
|
59
|
+
# Note: technically we could make a batch request here,
|
60
|
+
# but not all APIs will allow for it. For example, vllm won't allow more than 1.
|
61
|
+
for image in images:
|
62
|
+
img_io = io.BytesIO()
|
63
|
+
image.save(img_io, "PNG")
|
64
|
+
image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")
|
65
|
+
|
66
|
+
messages = [
|
67
|
+
{
|
68
|
+
"role": "user",
|
69
|
+
"content": [
|
70
|
+
{
|
71
|
+
"type": "text",
|
72
|
+
"text": self.options.prompt,
|
73
|
+
},
|
74
|
+
{
|
75
|
+
"type": "image_url",
|
76
|
+
"image_url": {
|
77
|
+
"url": f"data:image/png;base64,{image_base64}"
|
78
|
+
},
|
79
|
+
},
|
80
|
+
],
|
81
|
+
}
|
82
|
+
]
|
83
|
+
|
84
|
+
payload = {
|
85
|
+
"messages": messages,
|
86
|
+
**self.options.params,
|
87
|
+
}
|
88
|
+
|
89
|
+
r = requests.post(
|
90
|
+
str(self.options.url),
|
91
|
+
headers=self.options.headers,
|
92
|
+
json=payload,
|
93
|
+
timeout=self.options.timeout,
|
94
|
+
)
|
95
|
+
if not r.ok:
|
96
|
+
_log.error(f"Error calling the API. Reponse was {r.text}")
|
97
|
+
r.raise_for_status()
|
98
|
+
|
99
|
+
api_resp = ApiResponse.model_validate_json(r.text)
|
100
|
+
generated_text = api_resp.choices[0].message.content.strip()
|
101
|
+
yield generated_text
|
docling/models/picture_description_base_model.py
@@ -0,0 +1,64 @@
|
|
1
|
+
import logging
|
2
|
+
from pathlib import Path
|
3
|
+
from typing import Any, Iterable, List, Optional, Union
|
4
|
+
|
5
|
+
from docling_core.types.doc import (
|
6
|
+
DoclingDocument,
|
7
|
+
NodeItem,
|
8
|
+
PictureClassificationClass,
|
9
|
+
PictureItem,
|
10
|
+
)
|
11
|
+
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
|
12
|
+
PictureDescriptionData,
|
13
|
+
)
|
14
|
+
from PIL import Image
|
15
|
+
|
16
|
+
from docling.datamodel.pipeline_options import PictureDescriptionBaseOptions
|
17
|
+
from docling.models.base_model import (
|
18
|
+
BaseItemAndImageEnrichmentModel,
|
19
|
+
ItemAndImageEnrichmentElement,
|
20
|
+
)
|
21
|
+
|
22
|
+
|
23
|
+
class PictureDescriptionBaseModel(BaseItemAndImageEnrichmentModel):
|
24
|
+
images_scale: float = 2.0
|
25
|
+
|
26
|
+
def __init__(
|
27
|
+
self,
|
28
|
+
enabled: bool,
|
29
|
+
options: PictureDescriptionBaseOptions,
|
30
|
+
):
|
31
|
+
self.enabled = enabled
|
32
|
+
self.options = options
|
33
|
+
self.provenance = "not-implemented"
|
34
|
+
|
35
|
+
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
|
36
|
+
return self.enabled and isinstance(element, PictureItem)
|
37
|
+
|
38
|
+
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
|
39
|
+
raise NotImplementedError
|
40
|
+
|
41
|
+
def __call__(
|
42
|
+
self,
|
43
|
+
doc: DoclingDocument,
|
44
|
+
element_batch: Iterable[ItemAndImageEnrichmentElement],
|
45
|
+
) -> Iterable[NodeItem]:
|
46
|
+
if not self.enabled:
|
47
|
+
for element in element_batch:
|
48
|
+
yield element.item
|
49
|
+
return
|
50
|
+
|
51
|
+
images: List[Image.Image] = []
|
52
|
+
elements: List[PictureItem] = []
|
53
|
+
for el in element_batch:
|
54
|
+
assert isinstance(el.item, PictureItem)
|
55
|
+
elements.append(el.item)
|
56
|
+
images.append(el.image)
|
57
|
+
|
58
|
+
outputs = self._annotate_images(images)
|
59
|
+
|
60
|
+
for item, output in zip(elements, outputs):
|
61
|
+
item.annotations.append(
|
62
|
+
PictureDescriptionData(text=output, provenance=self.provenance)
|
63
|
+
)
|
64
|
+
yield item
|
docling/models/picture_description_vlm_model.py
@@ -0,0 +1,109 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
from typing import Iterable, Optional, Union
|
3
|
+
|
4
|
+
from PIL import Image
|
5
|
+
|
6
|
+
from docling.datamodel.pipeline_options import (
|
7
|
+
AcceleratorOptions,
|
8
|
+
PictureDescriptionVlmOptions,
|
9
|
+
)
|
10
|
+
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
|
11
|
+
from docling.utils.accelerator_utils import decide_device
|
12
|
+
|
13
|
+
|
14
|
+
class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
|
15
|
+
|
16
|
+
def __init__(
|
17
|
+
self,
|
18
|
+
enabled: bool,
|
19
|
+
artifacts_path: Optional[Union[Path, str]],
|
20
|
+
options: PictureDescriptionVlmOptions,
|
21
|
+
accelerator_options: AcceleratorOptions,
|
22
|
+
):
|
23
|
+
super().__init__(enabled=enabled, options=options)
|
24
|
+
self.options: PictureDescriptionVlmOptions
|
25
|
+
|
26
|
+
if self.enabled:
|
27
|
+
|
28
|
+
if artifacts_path is None:
|
29
|
+
artifacts_path = self.download_models(repo_id=self.options.repo_id)
|
30
|
+
else:
|
31
|
+
artifacts_path = Path(artifacts_path) / self.options.repo_cache_folder
|
32
|
+
|
33
|
+
self.device = decide_device(accelerator_options.device)
|
34
|
+
|
35
|
+
try:
|
36
|
+
import torch
|
37
|
+
from transformers import AutoModelForVision2Seq, AutoProcessor
|
38
|
+
except ImportError:
|
39
|
+
raise ImportError(
|
40
|
+
"transformers >=4.46 is not installed. Please install Docling with the required extras `pip install docling[vlm]`."
|
41
|
+
)
|
42
|
+
|
43
|
+
# Initialize processor and model
|
44
|
+
self.processor = AutoProcessor.from_pretrained(self.options.repo_id)
|
45
|
+
self.model = AutoModelForVision2Seq.from_pretrained(
|
46
|
+
self.options.repo_id,
|
47
|
+
torch_dtype=torch.bfloat16,
|
48
|
+
_attn_implementation=(
|
49
|
+
"flash_attention_2" if self.device.startswith("cuda") else "eager"
|
50
|
+
),
|
51
|
+
).to(self.device)
|
52
|
+
|
53
|
+
self.provenance = f"{self.options.repo_id}"
|
54
|
+
|
55
|
+
@staticmethod
|
56
|
+
def download_models(
|
57
|
+
repo_id: str,
|
58
|
+
local_dir: Optional[Path] = None,
|
59
|
+
force: bool = False,
|
60
|
+
progress: bool = False,
|
61
|
+
) -> Path:
|
62
|
+
from huggingface_hub import snapshot_download
|
63
|
+
from huggingface_hub.utils import disable_progress_bars
|
64
|
+
|
65
|
+
if not progress:
|
66
|
+
disable_progress_bars()
|
67
|
+
download_path = snapshot_download(
|
68
|
+
repo_id=repo_id,
|
69
|
+
force_download=force,
|
70
|
+
local_dir=local_dir,
|
71
|
+
)
|
72
|
+
|
73
|
+
return Path(download_path)
|
74
|
+
|
75
|
+
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
|
76
|
+
from transformers import GenerationConfig
|
77
|
+
|
78
|
+
# Create input messages
|
79
|
+
messages = [
|
80
|
+
{
|
81
|
+
"role": "user",
|
82
|
+
"content": [
|
83
|
+
{"type": "image"},
|
84
|
+
{"type": "text", "text": self.options.prompt},
|
85
|
+
],
|
86
|
+
},
|
87
|
+
]
|
88
|
+
|
89
|
+
# TODO: do batch generation
|
90
|
+
|
91
|
+
for image in images:
|
92
|
+
# Prepare inputs
|
93
|
+
prompt = self.processor.apply_chat_template(
|
94
|
+
messages, add_generation_prompt=True
|
95
|
+
)
|
96
|
+
inputs = self.processor(text=prompt, images=[image], return_tensors="pt")
|
97
|
+
inputs = inputs.to(self.device)
|
98
|
+
|
99
|
+
# Generate outputs
|
100
|
+
generated_ids = self.model.generate(
|
101
|
+
**inputs,
|
102
|
+
generation_config=GenerationConfig(**self.options.generation_config),
|
103
|
+
)
|
104
|
+
generated_texts = self.processor.batch_decode(
|
105
|
+
generated_ids[:, inputs["input_ids"].shape[1] :],
|
106
|
+
skip_special_tokens=True,
|
107
|
+
)
|
108
|
+
|
109
|
+
yield generated_texts[0].strip()
|
docling/models/table_structure_model.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
import copy
|
2
|
+
import warnings
|
2
3
|
from pathlib import Path
|
3
|
-
from typing import Iterable
|
4
|
+
from typing import Iterable, Optional, Union
|
4
5
|
|
5
6
|
import numpy
|
6
7
|
from docling_core.types.doc import BoundingBox, DocItemLabel, TableCell
|
@@ -22,10 +23,13 @@ from docling.utils.profiling import TimeRecorder
|
|
22
23
|
|
23
24
|
|
24
25
|
class TableStructureModel(BasePageModel):
|
26
|
+
_model_repo_folder = "ds4sd--docling-models"
|
27
|
+
_model_path = "model_artifacts/tableformer"
|
28
|
+
|
25
29
|
def __init__(
|
26
30
|
self,
|
27
31
|
enabled: bool,
|
28
|
-
artifacts_path: Path,
|
32
|
+
artifacts_path: Optional[Path],
|
29
33
|
options: TableStructureOptions,
|
30
34
|
accelerator_options: AcceleratorOptions,
|
31
35
|
):
|
@@ -35,6 +39,26 @@ class TableStructureModel(BasePageModel):
|
|
35
39
|
|
36
40
|
self.enabled = enabled
|
37
41
|
if self.enabled:
|
42
|
+
|
43
|
+
if artifacts_path is None:
|
44
|
+
artifacts_path = self.download_models() / self._model_path
|
45
|
+
else:
|
46
|
+
# will become the default in the future
|
47
|
+
if (artifacts_path / self._model_repo_folder).exists():
|
48
|
+
artifacts_path = (
|
49
|
+
artifacts_path / self._model_repo_folder / self._model_path
|
50
|
+
)
|
51
|
+
elif (artifacts_path / self._model_path).exists():
|
52
|
+
warnings.warn(
|
53
|
+
"The usage of artifacts_path containing directly "
|
54
|
+
f"{self._model_path} is deprecated. Please point "
|
55
|
+
"the artifacts_path to the parent containing "
|
56
|
+
f"the {self._model_repo_folder} folder.",
|
57
|
+
DeprecationWarning,
|
58
|
+
stacklevel=3,
|
59
|
+
)
|
60
|
+
artifacts_path = artifacts_path / self._model_path
|
61
|
+
|
38
62
|
if self.mode == TableFormerMode.ACCURATE:
|
39
63
|
artifacts_path = artifacts_path / "accurate"
|
40
64
|
else:
|
@@ -58,6 +82,24 @@ class TableStructureModel(BasePageModel):
|
|
58
82
|
)
|
59
83
|
self.scale = 2.0 # Scale up table input images to 144 dpi
|
60
84
|
|
85
|
+
@staticmethod
|
86
|
+
def download_models(
|
87
|
+
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
|
88
|
+
) -> Path:
|
89
|
+
from huggingface_hub import snapshot_download
|
90
|
+
from huggingface_hub.utils import disable_progress_bars
|
91
|
+
|
92
|
+
if not progress:
|
93
|
+
disable_progress_bars()
|
94
|
+
download_path = snapshot_download(
|
95
|
+
repo_id="ds4sd/docling-models",
|
96
|
+
force_download=force,
|
97
|
+
local_dir=local_dir,
|
98
|
+
revision="v2.1.0",
|
99
|
+
)
|
100
|
+
|
101
|
+
return Path(download_path)
|
102
|
+
|
61
103
|
def draw_table_and_cells(
|
62
104
|
self,
|
63
105
|
conv_res: ConversionResult,
|
docling/pipeline/base_pipeline.py
CHANGED
@@ -79,7 +79,7 @@ class BasePipeline(ABC):
|
|
79
79
|
for model in self.enrichment_pipe:
|
80
80
|
for element_batch in chunkify(
|
81
81
|
_prepare_elements(conv_res, model),
|
82
|
-
|
82
|
+
model.elements_batch_size,
|
83
83
|
):
|
84
84
|
for element in model(
|
85
85
|
doc=conv_res.document, element_batch=element_batch
|