docling 2.35.0__py3-none-any.whl → 2.36.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling/backend/xml/jats_backend.py +0 -0
- docling/cli/main.py +12 -15
- docling/datamodel/accelerator_options.py +68 -0
- docling/datamodel/base_models.py +10 -8
- docling/datamodel/pipeline_options.py +29 -161
- docling/datamodel/pipeline_options_vlm_model.py +81 -0
- docling/datamodel/vlm_model_specs.py +144 -0
- docling/document_converter.py +5 -0
- docling/models/api_vlm_model.py +1 -1
- docling/models/base_ocr_model.py +2 -1
- docling/models/code_formula_model.py +6 -11
- docling/models/document_picture_classifier.py +6 -11
- docling/models/easyocr_model.py +1 -2
- docling/models/layout_model.py +6 -11
- docling/models/ocr_mac_model.py +1 -1
- docling/models/picture_description_api_model.py +1 -1
- docling/models/picture_description_base_model.py +1 -1
- docling/models/picture_description_vlm_model.py +7 -22
- docling/models/rapid_ocr_model.py +1 -2
- docling/models/table_structure_model.py +6 -12
- docling/models/tesseract_ocr_cli_model.py +1 -1
- docling/models/tesseract_ocr_model.py +1 -1
- docling/models/utils/__init__.py +0 -0
- docling/models/utils/hf_model_download.py +40 -0
- docling/models/vlm_models_inline/__init__.py +0 -0
- docling/models/vlm_models_inline/hf_transformers_model.py +194 -0
- docling/models/{hf_mlx_model.py → vlm_models_inline/mlx_model.py} +56 -44
- docling/pipeline/vlm_pipeline.py +228 -61
- docling/utils/accelerator_utils.py +17 -2
- docling/utils/model_downloader.py +13 -12
- {docling-2.35.0.dist-info → docling-2.36.0.dist-info}/METADATA +54 -55
- {docling-2.35.0.dist-info → docling-2.36.0.dist-info}/RECORD +46 -39
- {docling-2.35.0.dist-info → docling-2.36.0.dist-info}/WHEEL +2 -1
- docling-2.36.0.dist-info/entry_points.txt +6 -0
- docling-2.36.0.dist-info/top_level.txt +1 -0
- docling/models/hf_vlm_model.py +0 -182
- docling-2.35.0.dist-info/entry_points.txt +0 -7
- {docling-2.35.0.dist-info → docling-2.36.0.dist-info/licenses}/LICENSE +0 -0
@@ -16,9 +16,10 @@ from docling_core.types.doc.labels import CodeLanguageLabel
|
|
16
16
|
from PIL import Image, ImageOps
|
17
17
|
from pydantic import BaseModel
|
18
18
|
|
19
|
+
from docling.datamodel.accelerator_options import AcceleratorOptions
|
19
20
|
from docling.datamodel.base_models import ItemAndImageEnrichmentElement
|
20
|
-
from docling.datamodel.pipeline_options import AcceleratorOptions
|
21
21
|
from docling.models.base_model import BaseItemAndImageEnrichmentModel
|
22
|
+
from docling.models.utils.hf_model_download import download_hf_model
|
22
23
|
from docling.utils.accelerator_utils import decide_device
|
23
24
|
|
24
25
|
|
@@ -117,20 +118,14 @@ class CodeFormulaModel(BaseItemAndImageEnrichmentModel):
|
|
117
118
|
force: bool = False,
|
118
119
|
progress: bool = False,
|
119
120
|
) -> Path:
|
120
|
-
|
121
|
-
from huggingface_hub.utils import disable_progress_bars
|
122
|
-
|
123
|
-
if not progress:
|
124
|
-
disable_progress_bars()
|
125
|
-
download_path = snapshot_download(
|
121
|
+
return download_hf_model(
|
126
122
|
repo_id="ds4sd/CodeFormula",
|
127
|
-
force_download=force,
|
128
|
-
local_dir=local_dir,
|
129
123
|
revision="v1.0.2",
|
124
|
+
local_dir=local_dir,
|
125
|
+
force=force,
|
126
|
+
progress=progress,
|
130
127
|
)
|
131
128
|
|
132
|
-
return Path(download_path)
|
133
|
-
|
134
129
|
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
|
135
130
|
"""
|
136
131
|
Determines if a given element in a document can be processed by the model.
|
@@ -13,8 +13,9 @@ from docling_core.types.doc import (
|
|
13
13
|
from PIL import Image
|
14
14
|
from pydantic import BaseModel
|
15
15
|
|
16
|
-
from docling.datamodel.
|
16
|
+
from docling.datamodel.accelerator_options import AcceleratorOptions
|
17
17
|
from docling.models.base_model import BaseEnrichmentModel
|
18
|
+
from docling.models.utils.hf_model_download import download_hf_model
|
18
19
|
from docling.utils.accelerator_utils import decide_device
|
19
20
|
|
20
21
|
|
@@ -105,20 +106,14 @@ class DocumentPictureClassifier(BaseEnrichmentModel):
|
|
105
106
|
def download_models(
|
106
107
|
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
|
107
108
|
) -> Path:
|
108
|
-
|
109
|
-
from huggingface_hub.utils import disable_progress_bars
|
110
|
-
|
111
|
-
if not progress:
|
112
|
-
disable_progress_bars()
|
113
|
-
download_path = snapshot_download(
|
109
|
+
return download_hf_model(
|
114
110
|
repo_id="ds4sd/DocumentFigureClassifier",
|
115
|
-
force_download=force,
|
116
|
-
local_dir=local_dir,
|
117
111
|
revision="v1.0.1",
|
112
|
+
local_dir=local_dir,
|
113
|
+
force=force,
|
114
|
+
progress=progress,
|
118
115
|
)
|
119
116
|
|
120
|
-
return Path(download_path)
|
121
|
-
|
122
117
|
def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
|
123
118
|
"""
|
124
119
|
Determines if the given element can be processed by the classifier.
|
docling/models/easyocr_model.py
CHANGED
@@ -9,11 +9,10 @@ import numpy
|
|
9
9
|
from docling_core.types.doc import BoundingBox, CoordOrigin
|
10
10
|
from docling_core.types.doc.page import BoundingRectangle, TextCell
|
11
11
|
|
12
|
+
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
12
13
|
from docling.datamodel.base_models import Page
|
13
14
|
from docling.datamodel.document import ConversionResult
|
14
15
|
from docling.datamodel.pipeline_options import (
|
15
|
-
AcceleratorDevice,
|
16
|
-
AcceleratorOptions,
|
17
16
|
EasyOcrOptions,
|
18
17
|
OcrOptions,
|
19
18
|
)
|
docling/models/layout_model.py
CHANGED
@@ -10,11 +10,12 @@ from docling_core.types.doc import DocItemLabel
|
|
10
10
|
from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor
|
11
11
|
from PIL import Image
|
12
12
|
|
13
|
+
from docling.datamodel.accelerator_options import AcceleratorOptions
|
13
14
|
from docling.datamodel.base_models import BoundingBox, Cluster, LayoutPrediction, Page
|
14
15
|
from docling.datamodel.document import ConversionResult
|
15
|
-
from docling.datamodel.pipeline_options import AcceleratorOptions
|
16
16
|
from docling.datamodel.settings import settings
|
17
17
|
from docling.models.base_model import BasePageModel
|
18
|
+
from docling.models.utils.hf_model_download import download_hf_model
|
18
19
|
from docling.utils.accelerator_utils import decide_device
|
19
20
|
from docling.utils.layout_postprocessor import LayoutPostprocessor
|
20
21
|
from docling.utils.profiling import TimeRecorder
|
@@ -83,20 +84,14 @@ class LayoutModel(BasePageModel):
|
|
83
84
|
force: bool = False,
|
84
85
|
progress: bool = False,
|
85
86
|
) -> Path:
|
86
|
-
|
87
|
-
from huggingface_hub.utils import disable_progress_bars
|
88
|
-
|
89
|
-
if not progress:
|
90
|
-
disable_progress_bars()
|
91
|
-
download_path = snapshot_download(
|
87
|
+
return download_hf_model(
|
92
88
|
repo_id="ds4sd/docling-models",
|
93
|
-
|
89
|
+
revision="v2.2.0",
|
94
90
|
local_dir=local_dir,
|
95
|
-
|
91
|
+
force=force,
|
92
|
+
progress=progress,
|
96
93
|
)
|
97
94
|
|
98
|
-
return Path(download_path)
|
99
|
-
|
100
95
|
def draw_clusters_and_cells_side_by_side(
|
101
96
|
self, conv_res, page, clusters, mode_prefix: str, show: bool = False
|
102
97
|
):
|
docling/models/ocr_mac_model.py
CHANGED
@@ -8,10 +8,10 @@ from typing import Optional, Type
|
|
8
8
|
from docling_core.types.doc import BoundingBox, CoordOrigin
|
9
9
|
from docling_core.types.doc.page import BoundingRectangle, TextCell
|
10
10
|
|
11
|
+
from docling.datamodel.accelerator_options import AcceleratorOptions
|
11
12
|
from docling.datamodel.base_models import Page
|
12
13
|
from docling.datamodel.document import ConversionResult
|
13
14
|
from docling.datamodel.pipeline_options import (
|
14
|
-
AcceleratorOptions,
|
15
15
|
OcrMacOptions,
|
16
16
|
OcrOptions,
|
17
17
|
)
|
@@ -5,8 +5,8 @@ from typing import Optional, Type, Union
|
|
5
5
|
|
6
6
|
from PIL import Image
|
7
7
|
|
8
|
+
from docling.datamodel.accelerator_options import AcceleratorOptions
|
8
9
|
from docling.datamodel.pipeline_options import (
|
9
|
-
AcceleratorOptions,
|
10
10
|
PictureDescriptionApiOptions,
|
11
11
|
PictureDescriptionBaseOptions,
|
12
12
|
)
|
@@ -13,8 +13,8 @@ from docling_core.types.doc.document import ( # TODO: move import to docling_co
|
|
13
13
|
)
|
14
14
|
from PIL import Image
|
15
15
|
|
16
|
+
from docling.datamodel.accelerator_options import AcceleratorOptions
|
16
17
|
from docling.datamodel.pipeline_options import (
|
17
|
-
AcceleratorOptions,
|
18
18
|
PictureDescriptionBaseOptions,
|
19
19
|
)
|
20
20
|
from docling.models.base_model import (
|
@@ -4,16 +4,21 @@ from typing import Optional, Type, Union
|
|
4
4
|
|
5
5
|
from PIL import Image
|
6
6
|
|
7
|
+
from docling.datamodel.accelerator_options import AcceleratorOptions
|
7
8
|
from docling.datamodel.pipeline_options import (
|
8
|
-
AcceleratorOptions,
|
9
9
|
PictureDescriptionBaseOptions,
|
10
10
|
PictureDescriptionVlmOptions,
|
11
11
|
)
|
12
12
|
from docling.models.picture_description_base_model import PictureDescriptionBaseModel
|
13
|
+
from docling.models.utils.hf_model_download import (
|
14
|
+
HuggingFaceModelDownloadMixin,
|
15
|
+
)
|
13
16
|
from docling.utils.accelerator_utils import decide_device
|
14
17
|
|
15
18
|
|
16
|
-
class PictureDescriptionVlmModel(
|
19
|
+
class PictureDescriptionVlmModel(
|
20
|
+
PictureDescriptionBaseModel, HuggingFaceModelDownloadMixin
|
21
|
+
):
|
17
22
|
@classmethod
|
18
23
|
def get_options_type(cls) -> Type[PictureDescriptionBaseOptions]:
|
19
24
|
return PictureDescriptionVlmOptions
|
@@ -66,26 +71,6 @@ class PictureDescriptionVlmModel(PictureDescriptionBaseModel):
|
|
66
71
|
|
67
72
|
self.provenance = f"{self.options.repo_id}"
|
68
73
|
|
69
|
-
@staticmethod
|
70
|
-
def download_models(
|
71
|
-
repo_id: str,
|
72
|
-
local_dir: Optional[Path] = None,
|
73
|
-
force: bool = False,
|
74
|
-
progress: bool = False,
|
75
|
-
) -> Path:
|
76
|
-
from huggingface_hub import snapshot_download
|
77
|
-
from huggingface_hub.utils import disable_progress_bars
|
78
|
-
|
79
|
-
if not progress:
|
80
|
-
disable_progress_bars()
|
81
|
-
download_path = snapshot_download(
|
82
|
-
repo_id=repo_id,
|
83
|
-
force_download=force,
|
84
|
-
local_dir=local_dir,
|
85
|
-
)
|
86
|
-
|
87
|
-
return Path(download_path)
|
88
|
-
|
89
74
|
def _annotate_images(self, images: Iterable[Image.Image]) -> Iterable[str]:
|
90
75
|
from transformers import GenerationConfig
|
91
76
|
|
@@ -7,11 +7,10 @@ import numpy
|
|
7
7
|
from docling_core.types.doc import BoundingBox, CoordOrigin
|
8
8
|
from docling_core.types.doc.page import BoundingRectangle, TextCell
|
9
9
|
|
10
|
+
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
10
11
|
from docling.datamodel.base_models import Page
|
11
12
|
from docling.datamodel.document import ConversionResult
|
12
13
|
from docling.datamodel.pipeline_options import (
|
13
|
-
AcceleratorDevice,
|
14
|
-
AcceleratorOptions,
|
15
14
|
OcrOptions,
|
16
15
|
RapidOcrOptions,
|
17
16
|
)
|
@@ -13,16 +13,16 @@ from docling_core.types.doc.page import (
|
|
13
13
|
from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor
|
14
14
|
from PIL import ImageDraw
|
15
15
|
|
16
|
+
from docling.datamodel.accelerator_options import AcceleratorDevice, AcceleratorOptions
|
16
17
|
from docling.datamodel.base_models import Page, Table, TableStructurePrediction
|
17
18
|
from docling.datamodel.document import ConversionResult
|
18
19
|
from docling.datamodel.pipeline_options import (
|
19
|
-
AcceleratorDevice,
|
20
|
-
AcceleratorOptions,
|
21
20
|
TableFormerMode,
|
22
21
|
TableStructureOptions,
|
23
22
|
)
|
24
23
|
from docling.datamodel.settings import settings
|
25
24
|
from docling.models.base_model import BasePageModel
|
25
|
+
from docling.models.utils.hf_model_download import download_hf_model
|
26
26
|
from docling.utils.accelerator_utils import decide_device
|
27
27
|
from docling.utils.profiling import TimeRecorder
|
28
28
|
|
@@ -90,20 +90,14 @@ class TableStructureModel(BasePageModel):
|
|
90
90
|
def download_models(
|
91
91
|
local_dir: Optional[Path] = None, force: bool = False, progress: bool = False
|
92
92
|
) -> Path:
|
93
|
-
|
94
|
-
from huggingface_hub.utils import disable_progress_bars
|
95
|
-
|
96
|
-
if not progress:
|
97
|
-
disable_progress_bars()
|
98
|
-
download_path = snapshot_download(
|
93
|
+
return download_hf_model(
|
99
94
|
repo_id="ds4sd/docling-models",
|
100
|
-
force_download=force,
|
101
|
-
local_dir=local_dir,
|
102
95
|
revision="v2.2.0",
|
96
|
+
local_dir=local_dir,
|
97
|
+
force=force,
|
98
|
+
progress=progress,
|
103
99
|
)
|
104
100
|
|
105
|
-
return Path(download_path)
|
106
|
-
|
107
101
|
def draw_table_and_cells(
|
108
102
|
self,
|
109
103
|
conv_res: ConversionResult,
|
@@ -13,10 +13,10 @@ import pandas as pd
|
|
13
13
|
from docling_core.types.doc import BoundingBox, CoordOrigin
|
14
14
|
from docling_core.types.doc.page import TextCell
|
15
15
|
|
16
|
+
from docling.datamodel.accelerator_options import AcceleratorOptions
|
16
17
|
from docling.datamodel.base_models import Page
|
17
18
|
from docling.datamodel.document import ConversionResult
|
18
19
|
from docling.datamodel.pipeline_options import (
|
19
|
-
AcceleratorOptions,
|
20
20
|
OcrOptions,
|
21
21
|
TesseractCliOcrOptions,
|
22
22
|
)
|
@@ -7,10 +7,10 @@ from typing import Iterable, Optional, Type
|
|
7
7
|
from docling_core.types.doc import BoundingBox, CoordOrigin
|
8
8
|
from docling_core.types.doc.page import TextCell
|
9
9
|
|
10
|
+
from docling.datamodel.accelerator_options import AcceleratorOptions
|
10
11
|
from docling.datamodel.base_models import Page
|
11
12
|
from docling.datamodel.document import ConversionResult
|
12
13
|
from docling.datamodel.pipeline_options import (
|
13
|
-
AcceleratorOptions,
|
14
14
|
OcrOptions,
|
15
15
|
TesseractOcrOptions,
|
16
16
|
)
|
File without changes
|
@@ -0,0 +1,40 @@
|
|
1
|
+
import logging
|
2
|
+
from pathlib import Path
|
3
|
+
from typing import Optional
|
4
|
+
|
5
|
+
_log = logging.getLogger(__name__)
|
6
|
+
|
7
|
+
|
8
|
+
def download_hf_model(
    repo_id: str,
    local_dir: Optional[Path] = None,
    force: bool = False,
    progress: bool = False,
    revision: Optional[str] = None,
) -> Path:
    """Download a snapshot of a Hugging Face Hub repository.

    Args:
        repo_id: Hub repository id, e.g. ``"ds4sd/docling-models"``.
        local_dir: Target directory, forwarded to ``snapshot_download``.
        force: Forwarded as ``force_download`` to re-fetch files.
        progress: When False, huggingface_hub progress bars are turned off
            for the duration of this call only.
        revision: Branch, tag or commit to fetch, forwarded to
            ``snapshot_download``.

    Returns:
        The local path returned by ``snapshot_download``, as a ``Path``.
    """
    from huggingface_hub import snapshot_download
    from huggingface_hub.utils import (
        are_progress_bars_disabled,
        disable_progress_bars,
        enable_progress_bars,
    )

    # disable_progress_bars() toggles a process-wide flag. Only switch it off
    # if it is currently on, and switch it back on afterwards, so a later
    # download requested with progress=True still shows progress.
    must_restore_progress = not progress and not are_progress_bars_disabled()
    if must_restore_progress:
        disable_progress_bars()
    try:
        download_path = snapshot_download(
            repo_id=repo_id,
            force_download=force,
            local_dir=local_dir,
            revision=revision,
        )
    finally:
        if must_restore_progress:
            enable_progress_bars()

    return Path(download_path)
|
28
|
+
|
29
|
+
|
30
|
+
class HuggingFaceModelDownloadMixin:
    """Mixin providing a ``download_models`` helper for an arbitrary HF repo.

    Unlike the per-model ``download_models`` functions that hard-code their
    repository, this variant takes the repository id as an argument. It can
    therefore be shared by models whose repo is only known at runtime.
    """

    @staticmethod
    def download_models(
        repo_id: str,
        local_dir: Optional[Path] = None,
        force: bool = False,
        progress: bool = False,
        revision: Optional[str] = None,
    ) -> Path:
        """Download ``repo_id`` via ``download_hf_model`` and return its path.

        Args:
            repo_id: Hub repository id to download.
            local_dir: Optional target directory.
            force: Re-download files even if present.
            progress: Show download progress bars.
            revision: Optional branch, tag or commit to pin. Defaults to
                ``None``, which is forwarded unchanged (previous behaviour).

        Returns:
            Local path of the downloaded snapshot.
        """
        return download_hf_model(
            repo_id=repo_id,
            local_dir=local_dir,
            force=force,
            progress=progress,
            revision=revision,
        )
|
File without changes
|
@@ -0,0 +1,194 @@
|
|
1
|
+
import importlib.metadata
|
2
|
+
import logging
|
3
|
+
import time
|
4
|
+
from collections.abc import Iterable
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Any, Optional
|
7
|
+
|
8
|
+
from docling.datamodel.accelerator_options import (
|
9
|
+
AcceleratorOptions,
|
10
|
+
)
|
11
|
+
from docling.datamodel.base_models import Page, VlmPrediction
|
12
|
+
from docling.datamodel.document import ConversionResult
|
13
|
+
from docling.datamodel.pipeline_options_vlm_model import (
|
14
|
+
InlineVlmOptions,
|
15
|
+
TransformersModelType,
|
16
|
+
)
|
17
|
+
from docling.models.base_model import BasePageModel
|
18
|
+
from docling.models.utils.hf_model_download import (
|
19
|
+
HuggingFaceModelDownloadMixin,
|
20
|
+
)
|
21
|
+
from docling.utils.accelerator_utils import decide_device
|
22
|
+
from docling.utils.profiling import TimeRecorder
|
23
|
+
|
24
|
+
_log = logging.getLogger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
def _version_tuple(version: str) -> tuple:
    """Return the leading ``(major, minor, patch)`` integers of *version*.

    Only the numeric prefix of each of the first three dot-separated parts is
    used, so ``"4.52.0.dev0"`` -> ``(4, 52, 0)``. Missing parts are padded
    with ``0`` (``"4.6"`` -> ``(4, 6, 0)``). Comparing these tuples is
    numeric, unlike comparing raw strings, where ``"4.6.0" >= "4.52.0"``.
    """
    import re

    numbers = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        numbers.append(int(match.group()))
    numbers.extend([0] * (3 - len(numbers)))
    return tuple(numbers)


class HuggingFaceTransformersVlmModel(BasePageModel, HuggingFaceModelDownloadMixin):
    """Page model that runs a Hugging Face ``transformers`` VLM on page images.

    For every valid page it renders an image at ``vlm_options.scale``, builds
    a prompt, calls ``generate`` and stores the decoded text in
    ``page.predictions.vlm_response``.
    """

    def __init__(
        self,
        enabled: bool,
        artifacts_path: Optional[Path],
        accelerator_options: AcceleratorOptions,
        vlm_options: InlineVlmOptions,
    ):
        """Load processor, model and generation config when ``enabled``.

        Args:
            enabled: When False nothing is loaded; only ``enabled`` and
                ``vlm_options`` are stored.
            artifacts_path: Directory with model files. If ``None`` the model
                is downloaded from ``vlm_options.repo_id``. If it contains a
                ``<org>--<name>`` sub-folder for that repo, the sub-folder is
                used instead.
            accelerator_options: Requested device and the CUDA
                flash-attention-2 flag.
            vlm_options: Repo selection, model class, prompt, quantization and
                generation settings.

        Raises:
            NotImplementedError: For ``microsoft/Phi-4-multimodal-instruct``
                when transformers >= 4.52.0 is installed.
        """
        self.enabled = enabled

        self.vlm_options = vlm_options

        if self.enabled:
            import torch
            from transformers import (
                AutoModel,
                AutoModelForCausalLM,
                AutoModelForVision2Seq,
                AutoProcessor,
                BitsAndBytesConfig,
                GenerationConfig,
            )

            transformers_version = importlib.metadata.version("transformers")
            # Compare numerically: a plain string comparison would treat
            # e.g. "4.6.0" as newer than "4.52.0".
            if (
                self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct"
                and _version_tuple(transformers_version) >= (4, 52, 0)
            ):
                raise NotImplementedError(
                    f"Phi 4 only works with transformers<4.52.0 but you have {transformers_version=}. Please downgrade by running pip install -U 'transformers<4.52.0'."
                )

            self.device = decide_device(
                accelerator_options.device,
                supported_devices=vlm_options.supported_devices,
            )
            _log.debug(f"Available device for VLM: {self.device}")

            self.use_cache = vlm_options.use_kv_cache
            self.max_new_tokens = vlm_options.max_new_tokens
            self.temperature = vlm_options.temperature

            # Offline layout: artifacts_path/<org>--<name>
            repo_cache_folder = vlm_options.repo_id.replace("/", "--")

            if artifacts_path is None:
                artifacts_path = self.download_models(self.vlm_options.repo_id)
            elif (artifacts_path / repo_cache_folder).exists():
                artifacts_path = artifacts_path / repo_cache_folder

            self.param_quantization_config: Optional[BitsAndBytesConfig] = None
            if vlm_options.quantized:
                self.param_quantization_config = BitsAndBytesConfig(
                    load_in_8bit=vlm_options.load_in_8bit,
                    llm_int8_threshold=vlm_options.llm_int8_threshold,
                )

            model_cls: Any = AutoModel
            if (
                self.vlm_options.transformers_model_type
                == TransformersModelType.AUTOMODEL_CAUSALLM
            ):
                model_cls = AutoModelForCausalLM
            elif (
                self.vlm_options.transformers_model_type
                == TransformersModelType.AUTOMODEL_VISION2SEQ
            ):
                model_cls = AutoModelForVision2Seq

            self.processor = AutoProcessor.from_pretrained(
                artifacts_path,
                trust_remote_code=vlm_options.trust_remote_code,
            )
            self.vlm_model = model_cls.from_pretrained(
                artifacts_path,
                device_map=self.device,
                # Previously built but never used, so `quantized=True` had no
                # effect. None (quantization off) is the loader default.
                quantization_config=self.param_quantization_config,
                _attn_implementation=(
                    "flash_attention_2"
                    if self.device.startswith("cuda")
                    and accelerator_options.cuda_use_flash_attention2
                    else "eager"
                ),
                trust_remote_code=vlm_options.trust_remote_code,
            )

            # Load generation config
            self.generation_config = GenerationConfig.from_pretrained(artifacts_path)

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Run the VLM on each page and attach its text as ``vlm_response``.

        Pages whose backend reports ``is_valid() == False`` are yielded
        untouched. Time spent per page is recorded under the ``"vlm"`` key.
        """
        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
            else:
                with TimeRecorder(conv_res, "vlm"):
                    assert page.size is not None

                    hi_res_image = page.get_image(scale=self.vlm_options.scale)

                    # Define prompt structure
                    prompt = self.formulate_prompt()

                    inputs = self.processor(
                        text=prompt, images=[hi_res_image], return_tensors="pt"
                    ).to(self.device)

                    start_time = time.time()
                    # Call model to generate:
                    generated_ids = self.vlm_model.generate(
                        **inputs,
                        max_new_tokens=self.max_new_tokens,
                        use_cache=self.use_cache,
                        temperature=self.temperature,
                        generation_config=self.generation_config,
                        **self.vlm_options.extra_generation_config,
                    )

                    generation_time = time.time() - start_time

                    # The output is sliced past the prompt length, i.e.
                    # generate() is expected to echo the prompt ids first.
                    # Decode and count only the newly generated part.
                    prompt_length = inputs["input_ids"].shape[1]
                    new_token_ids = generated_ids[:, prompt_length:]
                    generated_texts = self.processor.batch_decode(
                        new_token_ids,
                        skip_special_tokens=False,
                    )[0]

                    num_tokens = new_token_ids.shape[1]
                    _log.debug(
                        f"Generated {num_tokens} tokens in time {generation_time:.2f} seconds."
                    )
                    page.predictions.vlm_response = VlmPrediction(
                        text=generated_texts,
                        generation_time=generation_time,
                    )

                yield page

    def formulate_prompt(self) -> str:
        """Formulate a prompt for the VLM.

        Phi-4-multimodal-instruct gets a hand-built prompt using its special
        role/image tokens. All other models go through the processor's chat
        template, with a generic page intro, the image and
        ``vlm_options.prompt``.
        """

        if self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
            _log.debug("Using specialized prompt for Phi-4")
            # more info here: https://huggingface.co/microsoft/Phi-4-multimodal-instruct#loading-the-model-locally

            user_prompt = "<|user|>"
            assistant_prompt = "<|assistant|>"
            prompt_suffix = "<|end|>"

            prompt = f"{user_prompt}<|image_1|>{self.vlm_options.prompt}{prompt_suffix}{assistant_prompt}"
            _log.debug(f"prompt for {self.vlm_options.repo_id}: {prompt}")

            return prompt

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "This is a page from a document.",
                    },
                    {"type": "image"},
                    {"type": "text", "text": self.vlm_options.prompt},
                ],
            }
        ]
        # add_generation_prompt=True appends the assistant-turn header, so the
        # model starts answering directly. Without it, the model must emit the
        # role marker itself and that marker leaks into the decoded text.
        # This mirrors the Phi-4 branch above, which ends with <|assistant|>.
        prompt = self.processor.apply_chat_template(
            messages, add_generation_prompt=True
        )
        return prompt
|