natural-pdf 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/api/index.md +386 -0
- docs/assets/favicon.png +3 -0
- docs/assets/favicon.svg +3 -0
- docs/assets/javascripts/custom.js +17 -0
- docs/assets/logo.svg +3 -0
- docs/assets/sample-screen.png +0 -0
- docs/assets/social-preview.png +17 -0
- docs/assets/social-preview.svg +17 -0
- docs/assets/stylesheets/custom.css +65 -0
- docs/document-qa/index.ipynb +435 -0
- docs/document-qa/index.md +79 -0
- docs/element-selection/index.ipynb +915 -0
- docs/element-selection/index.md +229 -0
- docs/index.md +170 -0
- docs/installation/index.md +69 -0
- docs/interactive-widget/index.ipynb +962 -0
- docs/interactive-widget/index.md +12 -0
- docs/layout-analysis/index.ipynb +818 -0
- docs/layout-analysis/index.md +185 -0
- docs/ocr/index.md +209 -0
- docs/pdf-navigation/index.ipynb +314 -0
- docs/pdf-navigation/index.md +97 -0
- docs/regions/index.ipynb +816 -0
- docs/regions/index.md +294 -0
- docs/tables/index.ipynb +658 -0
- docs/tables/index.md +144 -0
- docs/text-analysis/index.ipynb +370 -0
- docs/text-analysis/index.md +105 -0
- docs/text-extraction/index.ipynb +1478 -0
- docs/text-extraction/index.md +292 -0
- docs/tutorials/01-loading-and-extraction.ipynb +1710 -0
- docs/tutorials/01-loading-and-extraction.md +95 -0
- docs/tutorials/02-finding-elements.ipynb +340 -0
- docs/tutorials/02-finding-elements.md +149 -0
- docs/tutorials/03-extracting-blocks.ipynb +147 -0
- docs/tutorials/03-extracting-blocks.md +48 -0
- docs/tutorials/04-table-extraction.ipynb +114 -0
- docs/tutorials/04-table-extraction.md +50 -0
- docs/tutorials/05-excluding-content.ipynb +270 -0
- docs/tutorials/05-excluding-content.md +109 -0
- docs/tutorials/06-document-qa.ipynb +332 -0
- docs/tutorials/06-document-qa.md +91 -0
- docs/tutorials/07-layout-analysis.ipynb +288 -0
- docs/tutorials/07-layout-analysis.md +66 -0
- docs/tutorials/07-working-with-regions.ipynb +413 -0
- docs/tutorials/07-working-with-regions.md +151 -0
- docs/tutorials/08-spatial-navigation.ipynb +508 -0
- docs/tutorials/08-spatial-navigation.md +190 -0
- docs/tutorials/09-section-extraction.ipynb +2434 -0
- docs/tutorials/09-section-extraction.md +256 -0
- docs/tutorials/10-form-field-extraction.ipynb +512 -0
- docs/tutorials/10-form-field-extraction.md +201 -0
- docs/tutorials/11-enhanced-table-processing.ipynb +54 -0
- docs/tutorials/11-enhanced-table-processing.md +9 -0
- docs/tutorials/12-ocr-integration.ipynb +604 -0
- docs/tutorials/12-ocr-integration.md +175 -0
- docs/tutorials/13-semantic-search.ipynb +1328 -0
- docs/tutorials/13-semantic-search.md +77 -0
- docs/visual-debugging/index.ipynb +2970 -0
- docs/visual-debugging/index.md +157 -0
- docs/visual-debugging/region.png +0 -0
- natural_pdf/__init__.py +50 -33
- natural_pdf/analyzers/__init__.py +2 -1
- natural_pdf/analyzers/layout/base.py +32 -24
- natural_pdf/analyzers/layout/docling.py +131 -72
- natural_pdf/analyzers/layout/gemini.py +264 -0
- natural_pdf/analyzers/layout/layout_analyzer.py +156 -113
- natural_pdf/analyzers/layout/layout_manager.py +125 -58
- natural_pdf/analyzers/layout/layout_options.py +43 -17
- natural_pdf/analyzers/layout/paddle.py +152 -95
- natural_pdf/analyzers/layout/surya.py +164 -92
- natural_pdf/analyzers/layout/tatr.py +149 -84
- natural_pdf/analyzers/layout/yolo.py +89 -45
- natural_pdf/analyzers/text_options.py +22 -15
- natural_pdf/analyzers/text_structure.py +131 -85
- natural_pdf/analyzers/utils.py +30 -23
- natural_pdf/collections/pdf_collection.py +146 -97
- natural_pdf/core/__init__.py +1 -1
- natural_pdf/core/element_manager.py +419 -337
- natural_pdf/core/highlighting_service.py +268 -196
- natural_pdf/core/page.py +1044 -521
- natural_pdf/core/pdf.py +516 -313
- natural_pdf/elements/__init__.py +1 -1
- natural_pdf/elements/base.py +307 -225
- natural_pdf/elements/collections.py +805 -543
- natural_pdf/elements/line.py +39 -36
- natural_pdf/elements/rect.py +32 -30
- natural_pdf/elements/region.py +889 -879
- natural_pdf/elements/text.py +127 -99
- natural_pdf/exporters/__init__.py +0 -1
- natural_pdf/exporters/searchable_pdf.py +261 -102
- natural_pdf/ocr/__init__.py +57 -35
- natural_pdf/ocr/engine.py +150 -46
- natural_pdf/ocr/engine_easyocr.py +146 -150
- natural_pdf/ocr/engine_paddle.py +118 -175
- natural_pdf/ocr/engine_surya.py +78 -141
- natural_pdf/ocr/ocr_factory.py +114 -0
- natural_pdf/ocr/ocr_manager.py +122 -124
- natural_pdf/ocr/ocr_options.py +16 -20
- natural_pdf/ocr/utils.py +98 -0
- natural_pdf/qa/__init__.py +1 -1
- natural_pdf/qa/document_qa.py +119 -111
- natural_pdf/search/__init__.py +37 -31
- natural_pdf/search/haystack_search_service.py +312 -189
- natural_pdf/search/haystack_utils.py +186 -122
- natural_pdf/search/search_options.py +25 -14
- natural_pdf/search/search_service_protocol.py +12 -6
- natural_pdf/search/searchable_mixin.py +261 -176
- natural_pdf/selectors/__init__.py +2 -1
- natural_pdf/selectors/parser.py +159 -316
- natural_pdf/templates/__init__.py +1 -1
- natural_pdf/templates/spa/css/style.css +334 -0
- natural_pdf/templates/spa/index.html +31 -0
- natural_pdf/templates/spa/js/app.js +472 -0
- natural_pdf/templates/spa/words.txt +235976 -0
- natural_pdf/utils/debug.py +32 -0
- natural_pdf/utils/highlighting.py +8 -2
- natural_pdf/utils/identifiers.py +29 -0
- natural_pdf/utils/packaging.py +418 -0
- natural_pdf/utils/reading_order.py +65 -63
- natural_pdf/utils/text_extraction.py +195 -0
- natural_pdf/utils/visualization.py +70 -61
- natural_pdf/widgets/__init__.py +2 -3
- natural_pdf/widgets/viewer.py +749 -718
- {natural_pdf-0.1.4.dist-info → natural_pdf-0.1.6.dist-info}/METADATA +53 -17
- natural_pdf-0.1.6.dist-info/RECORD +141 -0
- {natural_pdf-0.1.4.dist-info → natural_pdf-0.1.6.dist-info}/WHEEL +1 -1
- natural_pdf-0.1.6.dist-info/top_level.txt +4 -0
- notebooks/Examples.ipynb +1293 -0
- pdfs/.gitkeep +0 -0
- pdfs/01-practice.pdf +543 -0
- pdfs/0500000US42001.pdf +0 -0
- pdfs/0500000US42007.pdf +0 -0
- pdfs/2014 Statistics.pdf +0 -0
- pdfs/2019 Statistics.pdf +0 -0
- pdfs/Atlanta_Public_Schools_GA_sample.pdf +0 -0
- pdfs/needs-ocr.pdf +0 -0
- natural_pdf/templates/ocr_debug.html +0 -517
- natural_pdf-0.1.4.dist-info/RECORD +0 -61
- natural_pdf-0.1.4.dist-info/top_level.txt +0 -1
- {natural_pdf-0.1.4.dist-info → natural_pdf-0.1.6.dist-info}/licenses/LICENSE +0 -0
@@ -1,24 +1,38 @@
|
|
1
1
|
# layout_detector_docling.py
|
2
|
-
import logging
|
3
2
|
import importlib.util
|
3
|
+
import logging
|
4
4
|
import os
|
5
5
|
import tempfile
|
6
|
-
from typing import
|
6
|
+
from typing import Any, Dict, List, Optional
|
7
|
+
|
7
8
|
from PIL import Image
|
8
9
|
|
9
10
|
# Assuming base class and options are importable
|
10
11
|
try:
|
11
12
|
from .base import LayoutDetector
|
12
|
-
from .layout_options import
|
13
|
+
from .layout_options import BaseLayoutOptions, DoclingLayoutOptions
|
13
14
|
except ImportError:
|
14
15
|
# Placeholders if run standalone or imports fail
|
15
|
-
class BaseLayoutOptions:
|
16
|
-
|
16
|
+
class BaseLayoutOptions:
|
17
|
+
pass
|
18
|
+
|
19
|
+
class DoclingLayoutOptions(BaseLayoutOptions):
|
20
|
+
pass
|
21
|
+
|
17
22
|
class LayoutDetector:
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
23
|
+
def __init__(self):
|
24
|
+
self.logger = logging.getLogger()
|
25
|
+
self.supported_classes = set()
|
26
|
+
|
27
|
+
def _get_model(self, options):
|
28
|
+
raise NotImplementedError
|
29
|
+
|
30
|
+
def _normalize_class_name(self, n):
|
31
|
+
return n
|
32
|
+
|
33
|
+
def validate_classes(self, c):
|
34
|
+
pass
|
35
|
+
|
22
36
|
logging.basicConfig()
|
23
37
|
|
24
38
|
logger = logging.getLogger(__name__)
|
@@ -42,11 +56,27 @@ class DoclingLayoutDetector(LayoutDetector):
|
|
42
56
|
super().__init__()
|
43
57
|
# Docling classes are dynamic/hierarchical, define common ones
|
44
58
|
self.supported_classes = {
|
45
|
-
|
46
|
-
|
47
|
-
|
59
|
+
"Header",
|
60
|
+
"Footer",
|
61
|
+
"Paragraph",
|
62
|
+
"Heading",
|
63
|
+
"List",
|
64
|
+
"ListItem",
|
65
|
+
"Table",
|
66
|
+
"Figure",
|
67
|
+
"Caption",
|
68
|
+
"Footnote",
|
69
|
+
"PageNumber",
|
70
|
+
"Equation",
|
71
|
+
"Code",
|
72
|
+
"Title",
|
73
|
+
"Author",
|
74
|
+
"Abstract",
|
75
|
+
"Section",
|
76
|
+
"Unknown",
|
77
|
+
"Metadata", # Add more as needed
|
48
78
|
}
|
49
|
-
self._docling_document_cache = {}
|
79
|
+
self._docling_document_cache = {} # Cache the output doc per image/options if needed
|
50
80
|
|
51
81
|
def is_available(self) -> bool:
|
52
82
|
"""Check if docling is installed."""
|
@@ -55,9 +85,9 @@ class DoclingLayoutDetector(LayoutDetector):
|
|
55
85
|
def _get_cache_key(self, options: BaseLayoutOptions) -> str:
|
56
86
|
"""Generate cache key based on device and potentially converter args."""
|
57
87
|
if not isinstance(options, DoclingLayoutOptions):
|
58
|
-
|
88
|
+
options = DoclingLayoutOptions(device=options.device, extra_args=options.extra_args)
|
59
89
|
|
60
|
-
device_key = str(options.device).lower() if options.device else
|
90
|
+
device_key = str(options.device).lower() if options.device else "default_device"
|
61
91
|
# Include hash of extra_args if they affect model loading/converter init
|
62
92
|
extra_args_key = hash(frozenset(options.extra_args.items()))
|
63
93
|
return f"{self.__class__.__name__}_{device_key}_{extra_args_key}"
|
@@ -88,12 +118,17 @@ class DoclingLayoutDetector(LayoutDetector):
|
|
88
118
|
raise RuntimeError("Docling dependency not installed.")
|
89
119
|
|
90
120
|
if not isinstance(options, DoclingLayoutOptions):
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
121
|
+
self.logger.warning(
|
122
|
+
"Received BaseLayoutOptions, expected DoclingLayoutOptions. Using defaults."
|
123
|
+
)
|
124
|
+
options = DoclingLayoutOptions(
|
125
|
+
confidence=options.confidence,
|
126
|
+
classes=options.classes,
|
127
|
+
exclude_classes=options.exclude_classes,
|
128
|
+
device=options.device,
|
129
|
+
extra_args=options.extra_args,
|
130
|
+
verbose=options.extra_args.get("verbose", False),
|
131
|
+
)
|
97
132
|
|
98
133
|
# Validate classes before proceeding (note: Docling classes are case-sensitive)
|
99
134
|
# self.validate_classes(options.classes or []) # Validation might be tricky due to case sensitivity
|
@@ -105,18 +140,20 @@ class DoclingLayoutDetector(LayoutDetector):
|
|
105
140
|
|
106
141
|
# Docling convert method requires an image path. Save temp file.
|
107
142
|
detections = []
|
108
|
-
docling_doc = None
|
143
|
+
docling_doc = None # To store the result
|
109
144
|
with tempfile.TemporaryDirectory() as temp_dir:
|
110
145
|
temp_image_path = os.path.join(temp_dir, f"docling_input_{os.getpid()}.png")
|
111
146
|
try:
|
112
|
-
self.logger.debug(
|
113
|
-
|
147
|
+
self.logger.debug(
|
148
|
+
f"Saving temporary image for Docling detector to: {temp_image_path}"
|
149
|
+
)
|
150
|
+
image.convert("RGB").save(temp_image_path) # Ensure RGB
|
114
151
|
|
115
152
|
# Convert the document using Docling's DocumentConverter
|
116
153
|
self.logger.debug("Running Docling conversion...")
|
117
154
|
# Docling convert returns a Result object with a 'document' attribute
|
118
155
|
result = converter.convert(temp_image_path)
|
119
|
-
docling_doc = result.document
|
156
|
+
docling_doc = result.document # Store the DoclingDocument
|
120
157
|
self.logger.info(f"Docling conversion complete.")
|
121
158
|
|
122
159
|
# Convert Docling document to our detection format
|
@@ -124,12 +161,14 @@ class DoclingLayoutDetector(LayoutDetector):
|
|
124
161
|
|
125
162
|
except Exception as e:
|
126
163
|
self.logger.error(f"Error during Docling detection: {e}", exc_info=True)
|
127
|
-
raise
|
164
|
+
raise # Re-raise the exception
|
128
165
|
finally:
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
166
|
+
# Ensure temp file is removed
|
167
|
+
if os.path.exists(temp_image_path):
|
168
|
+
try:
|
169
|
+
os.remove(temp_image_path)
|
170
|
+
except OSError as e_rm:
|
171
|
+
self.logger.warning(f"Could not remove temp file {temp_image_path}: {e_rm}")
|
133
172
|
|
134
173
|
# Cache the docling document if needed elsewhere (maybe associate with page?)
|
135
174
|
# self._docling_document_cache[image_hash] = docling_doc # Needs a way to key this
|
@@ -137,26 +176,37 @@ class DoclingLayoutDetector(LayoutDetector):
|
|
137
176
|
self.logger.info(f"Docling detected {len(detections)} layout elements matching criteria.")
|
138
177
|
return detections
|
139
178
|
|
140
|
-
def _convert_docling_to_detections(
|
179
|
+
def _convert_docling_to_detections(
|
180
|
+
self, doc, options: DoclingLayoutOptions
|
181
|
+
) -> List[Dict[str, Any]]:
|
141
182
|
"""Convert a Docling document to our standard detection format."""
|
142
|
-
if not doc or not hasattr(doc,
|
183
|
+
if not doc or not hasattr(doc, "pages") or not doc.pages:
|
143
184
|
self.logger.warning("Invalid or empty Docling document for conversion.")
|
144
185
|
return []
|
145
186
|
|
146
187
|
detections = []
|
147
|
-
id_to_detection_index = {}
|
188
|
+
id_to_detection_index = {} # Map Docling ID to index in detections list
|
148
189
|
|
149
190
|
# Prepare normalized class filters once
|
150
|
-
normalized_classes_req =
|
151
|
-
|
191
|
+
normalized_classes_req = (
|
192
|
+
{self._normalize_class_name(c) for c in options.classes} if options.classes else None
|
193
|
+
)
|
194
|
+
normalized_classes_excl = (
|
195
|
+
{self._normalize_class_name(c) for c in options.exclude_classes}
|
196
|
+
if options.exclude_classes
|
197
|
+
else set()
|
198
|
+
)
|
152
199
|
|
153
200
|
# --- Iterate through elements using Docling's structure ---
|
154
201
|
# This requires traversing the hierarchy (e.g., doc.body.children)
|
155
202
|
# or iterating through specific lists like doc.texts, doc.tables etc.
|
156
203
|
elements_to_process = []
|
157
|
-
if hasattr(doc,
|
158
|
-
|
159
|
-
if hasattr(doc,
|
204
|
+
if hasattr(doc, "texts"):
|
205
|
+
elements_to_process.extend(doc.texts)
|
206
|
+
if hasattr(doc, "tables"):
|
207
|
+
elements_to_process.extend(doc.tables)
|
208
|
+
if hasattr(doc, "pictures"):
|
209
|
+
elements_to_process.extend(doc.pictures)
|
160
210
|
# Add other element types from DoclingDocument as needed
|
161
211
|
|
162
212
|
self.logger.debug(f"Converting {len(elements_to_process)} Docling elements...")
|
@@ -164,16 +214,19 @@ class DoclingLayoutDetector(LayoutDetector):
|
|
164
214
|
for elem in elements_to_process:
|
165
215
|
try:
|
166
216
|
# Get Provenance (bbox and page number)
|
167
|
-
if not hasattr(elem,
|
168
|
-
|
169
|
-
|
217
|
+
if not hasattr(elem, "prov") or not elem.prov:
|
218
|
+
continue
|
219
|
+
prov = elem.prov[0] # Use first provenance
|
220
|
+
if not hasattr(prov, "bbox") or not prov.bbox:
|
221
|
+
continue
|
170
222
|
bbox = prov.bbox
|
171
223
|
page_no = prov.page_no
|
172
224
|
|
173
225
|
# Get Page Dimensions (crucial for coordinate conversion)
|
174
|
-
if not hasattr(doc.pages.get(page_no),
|
226
|
+
if not hasattr(doc.pages.get(page_no), "size"):
|
227
|
+
continue
|
175
228
|
page_height = doc.pages[page_no].size.height
|
176
|
-
page_width = doc.pages[page_no].size.width
|
229
|
+
page_width = doc.pages[page_no].size.width # Needed? Bbox seems absolute
|
177
230
|
|
178
231
|
# Convert coordinates from Docling's system (often bottom-left origin)
|
179
232
|
# to standard top-left origin (0,0 at top-left)
|
@@ -182,46 +235,51 @@ class DoclingLayoutDetector(LayoutDetector):
|
|
182
235
|
x1 = float(bbox.r)
|
183
236
|
# Convert y: top_y = page_height - bottom_left_t
|
184
237
|
# bottom_y = page_height - bottom_left_b
|
185
|
-
y0 = float(page_height - bbox.t)
|
186
|
-
y1 = float(page_height - bbox.b)
|
238
|
+
y0 = float(page_height - bbox.t) # Top y
|
239
|
+
y1 = float(page_height - bbox.b) # Bottom y
|
187
240
|
|
188
241
|
# Ensure y0 < y1
|
189
|
-
if y0 > y1:
|
242
|
+
if y0 > y1:
|
243
|
+
y0, y1 = y1, y0
|
190
244
|
# Ensure x0 < x1
|
191
|
-
if x0 > x1:
|
245
|
+
if x0 > x1:
|
246
|
+
x0, x1 = x1, x0
|
192
247
|
|
193
248
|
# Get Class Label
|
194
|
-
label_orig = str(getattr(elem,
|
249
|
+
label_orig = str(getattr(elem, "label", "Unknown")) # Default if no label
|
195
250
|
normalized_label = self._normalize_class_name(label_orig)
|
196
251
|
|
197
252
|
# Apply Class Filtering
|
198
|
-
if normalized_classes_req and normalized_label not in normalized_classes_req:
|
199
|
-
|
253
|
+
if normalized_classes_req and normalized_label not in normalized_classes_req:
|
254
|
+
continue
|
255
|
+
if normalized_label in normalized_classes_excl:
|
256
|
+
continue
|
200
257
|
|
201
258
|
# Get Confidence (Docling often doesn't provide per-element confidence)
|
202
|
-
confidence = getattr(elem,
|
203
|
-
if confidence < options.confidence:
|
259
|
+
confidence = getattr(elem, "confidence", 0.95) # Assign default confidence
|
260
|
+
if confidence < options.confidence:
|
261
|
+
continue # Apply confidence threshold
|
204
262
|
|
205
263
|
# Get Text Content
|
206
|
-
text_content = getattr(elem,
|
264
|
+
text_content = getattr(elem, "text", None)
|
207
265
|
|
208
266
|
# Get IDs for hierarchy
|
209
|
-
docling_id = getattr(elem,
|
210
|
-
parent_id_obj = getattr(elem,
|
211
|
-
parent_id = getattr(parent_id_obj,
|
267
|
+
docling_id = getattr(elem, "self_ref", None)
|
268
|
+
parent_id_obj = getattr(elem, "parent", None)
|
269
|
+
parent_id = getattr(parent_id_obj, "self_ref", None) if parent_id_obj else None
|
212
270
|
|
213
271
|
# Create Detection Dictionary
|
214
272
|
detection = {
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
273
|
+
"bbox": (x0, y0, x1, y1),
|
274
|
+
"class": label_orig,
|
275
|
+
"normalized_class": normalized_label,
|
276
|
+
"confidence": confidence,
|
277
|
+
"text": text_content,
|
278
|
+
"docling_id": docling_id,
|
279
|
+
"parent_id": parent_id,
|
280
|
+
"page_number": page_no, # Add page number if useful
|
281
|
+
"source": "layout",
|
282
|
+
"model": "docling",
|
225
283
|
}
|
226
284
|
detections.append(detection)
|
227
285
|
|
@@ -229,8 +287,8 @@ class DoclingLayoutDetector(LayoutDetector):
|
|
229
287
|
# if docling_id: id_to_detection_index[docling_id] = len(detections) - 1
|
230
288
|
|
231
289
|
except Exception as conv_e:
|
232
|
-
|
233
|
-
|
290
|
+
self.logger.warning(f"Could not convert Docling element: {elem}. Error: {conv_e}")
|
291
|
+
continue
|
234
292
|
|
235
293
|
return detections
|
236
294
|
|
@@ -241,7 +299,8 @@ class DoclingLayoutDetector(LayoutDetector):
|
|
241
299
|
"""
|
242
300
|
# This requires caching the doc based on image/options or re-running.
|
243
301
|
# For simplicity, let's just re-run detect if needed.
|
244
|
-
self.logger.warning(
|
245
|
-
|
246
|
-
|
247
|
-
|
302
|
+
self.logger.warning(
|
303
|
+
"get_docling_document: Re-running detection to ensure document is generated."
|
304
|
+
)
|
305
|
+
self.detect(image, options) # Run detect to populate internal doc
|
306
|
+
return getattr(self, "_docling_document", None) # Return the stored doc
|
@@ -0,0 +1,264 @@
|
|
1
|
+
# layout_detector_gemini.py
|
2
|
+
import importlib.util
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
from typing import Any, Dict, List, Optional
|
6
|
+
import base64
|
7
|
+
import io
|
8
|
+
|
9
|
+
from pydantic import BaseModel, Field
|
10
|
+
from PIL import Image
|
11
|
+
|
12
|
+
# Use OpenAI library for interaction
|
13
|
+
try:
|
14
|
+
from openai import OpenAI
|
15
|
+
from openai.types.chat import ChatCompletion
|
16
|
+
# Import OpenAIError for exception handling if needed
|
17
|
+
except ImportError:
|
18
|
+
OpenAI = None
|
19
|
+
ChatCompletion = None
|
20
|
+
|
21
|
+
try:
|
22
|
+
from .base import LayoutDetector
|
23
|
+
from .layout_options import BaseLayoutOptions, GeminiLayoutOptions
|
24
|
+
except ImportError:
|
25
|
+
# Placeholders if run standalone or imports fail
|
26
|
+
class BaseLayoutOptions:
|
27
|
+
pass
|
28
|
+
|
29
|
+
class GeminiLayoutOptions(BaseLayoutOptions):
|
30
|
+
pass
|
31
|
+
|
32
|
+
class LayoutDetector:
|
33
|
+
def __init__(self):
|
34
|
+
self.logger = logging.getLogger()
|
35
|
+
self.supported_classes = set() # Will be dynamic based on user request
|
36
|
+
|
37
|
+
def _get_model(self, options):
|
38
|
+
raise NotImplementedError
|
39
|
+
|
40
|
+
def _normalize_class_name(self, n):
|
41
|
+
return n.lower().replace("_", "-").replace(" ", "-")
|
42
|
+
|
43
|
+
def validate_classes(self, c):
|
44
|
+
pass # Less strict validation needed for LLM
|
45
|
+
|
46
|
+
logging.basicConfig()
|
47
|
+
|
48
|
+
logger = logging.getLogger(__name__)
|
49
|
+
|
50
|
+
# Define Pydantic model for the expected output structure
|
51
|
+
# This is used by the openai library's `response_format`
|
52
|
+
class DetectedRegion(BaseModel):
|
53
|
+
label: str = Field(description="The identified class name.")
|
54
|
+
bbox: List[float] = Field(description="Bounding box coordinates [xmin, ymin, xmax, ymax].", min_items=4, max_items=4)
|
55
|
+
confidence: float = Field(description="Confidence score [0.0, 1.0].", ge=0.0, le=1.0)
|
56
|
+
|
57
|
+
|
58
|
+
class GeminiLayoutDetector(LayoutDetector):
|
59
|
+
"""Document layout detector using Google's Gemini models via OpenAI compatibility layer."""
|
60
|
+
|
61
|
+
# Base URL for the Gemini OpenAI-compatible endpoint
|
62
|
+
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
|
63
|
+
|
64
|
+
def __init__(self):
|
65
|
+
super().__init__()
|
66
|
+
self.supported_classes = set() # Indicate dynamic nature
|
67
|
+
|
68
|
+
def is_available(self) -> bool:
|
69
|
+
"""Check if openai library is installed and GOOGLE_API_KEY is available."""
|
70
|
+
api_key = os.environ.get("GOOGLE_API_KEY")
|
71
|
+
if not api_key:
|
72
|
+
logger.warning("GOOGLE_API_KEY environment variable not set. Gemini detector (via OpenAI lib) will not be available.")
|
73
|
+
return False
|
74
|
+
if OpenAI is None:
|
75
|
+
logger.warning("openai package not found. Gemini detector (via OpenAI lib) will not be available.")
|
76
|
+
return False
|
77
|
+
return True
|
78
|
+
|
79
|
+
def _get_cache_key(self, options: GeminiLayoutOptions) -> str:
|
80
|
+
"""Generate cache key based on model name."""
|
81
|
+
if not isinstance(options, GeminiLayoutOptions):
|
82
|
+
options = GeminiLayoutOptions() # Use defaults
|
83
|
+
|
84
|
+
model_key = options.model_name
|
85
|
+
# Prompt is built dynamically, so not part of cache key based on options
|
86
|
+
return f"{self.__class__.__name__}_{model_key}"
|
87
|
+
|
88
|
+
def _load_model_from_options(self, options: GeminiLayoutOptions) -> Any:
|
89
|
+
"""Validate options and return the model name."""
|
90
|
+
if not self.is_available():
|
91
|
+
raise RuntimeError(
|
92
|
+
"OpenAI library not installed or GOOGLE_API_KEY not set. Please run: pip install openai"
|
93
|
+
)
|
94
|
+
|
95
|
+
if not isinstance(options, GeminiLayoutOptions):
|
96
|
+
raise TypeError("Incorrect options type provided for Gemini model loading.")
|
97
|
+
|
98
|
+
# Simply return the model name, client is created in detect()
|
99
|
+
return options.model_name
|
100
|
+
|
101
|
+
def detect(self, image: Image.Image, options: BaseLayoutOptions) -> List[Dict[str, Any]]:
|
102
|
+
"""Detect layout elements in an image using Gemini via OpenAI library."""
|
103
|
+
if not self.is_available():
|
104
|
+
raise RuntimeError(
|
105
|
+
"OpenAI library not installed or GOOGLE_API_KEY not set."
|
106
|
+
)
|
107
|
+
|
108
|
+
# Ensure options are the correct type
|
109
|
+
if not isinstance(options, GeminiLayoutOptions):
|
110
|
+
self.logger.warning(
|
111
|
+
"Received BaseLayoutOptions, expected GeminiLayoutOptions. Using defaults."
|
112
|
+
)
|
113
|
+
options = GeminiLayoutOptions(
|
114
|
+
confidence=options.confidence,
|
115
|
+
classes=options.classes,
|
116
|
+
exclude_classes=options.exclude_classes,
|
117
|
+
device=options.device,
|
118
|
+
extra_args=options.extra_args,
|
119
|
+
)
|
120
|
+
|
121
|
+
model_name = self._get_model(options)
|
122
|
+
api_key = os.environ.get("GOOGLE_API_KEY")
|
123
|
+
|
124
|
+
detections = []
|
125
|
+
try:
|
126
|
+
# --- 1. Initialize OpenAI Client for Gemini ---
|
127
|
+
client = OpenAI(
|
128
|
+
api_key=api_key,
|
129
|
+
base_url=self.GEMINI_BASE_URL
|
130
|
+
)
|
131
|
+
|
132
|
+
# --- 2. Prepare Input for OpenAI API ---
|
133
|
+
if not options.classes:
|
134
|
+
logger.error("Gemini layout detection requires a list of classes to find.")
|
135
|
+
return []
|
136
|
+
|
137
|
+
width, height = image.size
|
138
|
+
|
139
|
+
# Convert image to base64
|
140
|
+
buffered = io.BytesIO()
|
141
|
+
image.save(buffered, format="PNG")
|
142
|
+
img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
143
|
+
image_url = f"data:image/png;base64,{img_base64}"
|
144
|
+
|
145
|
+
# Construct the prompt text
|
146
|
+
class_list_str = ", ".join(f'`{c}`' for c in options.classes)
|
147
|
+
prompt_text = (
|
148
|
+
f"Analyze the provided image of a document page ({width}x{height}). "
|
149
|
+
f"Identify all regions corresponding to the following types: {class_list_str}. "
|
150
|
+
f"Return ONLY the structured data requested."
|
151
|
+
)
|
152
|
+
|
153
|
+
# Prepare messages for chat completions endpoint
|
154
|
+
messages = [
|
155
|
+
{
|
156
|
+
"role": "user",
|
157
|
+
"content": [
|
158
|
+
{"type": "text", "text": prompt_text},
|
159
|
+
{
|
160
|
+
"type": "image_url",
|
161
|
+
"image_url": {"url": image_url},
|
162
|
+
},
|
163
|
+
],
|
164
|
+
}
|
165
|
+
]
|
166
|
+
|
167
|
+
# --- 3. Call OpenAI API using .parse for structured output ---
|
168
|
+
logger.debug(f"Running Gemini detection via OpenAI lib (Model: {model_name}). Asking for classes: {options.classes}")
|
169
|
+
|
170
|
+
# Extract relevant generation parameters from extra_args if provided
|
171
|
+
# Mapping common names: temperature, top_p, max_tokens
|
172
|
+
completion_kwargs = {
|
173
|
+
"temperature": options.extra_args.get("temperature", 0.2), # Default to low temp
|
174
|
+
"top_p": options.extra_args.get("top_p"),
|
175
|
+
"max_tokens": options.extra_args.get("max_tokens", 4096), # Map from max_output_tokens
|
176
|
+
}
|
177
|
+
# Filter out None values
|
178
|
+
completion_kwargs = {k: v for k, v in completion_kwargs.items() if v is not None}
|
179
|
+
|
180
|
+
completion: ChatCompletion = client.beta.chat.completions.parse(
|
181
|
+
model=model_name,
|
182
|
+
messages=messages,
|
183
|
+
response_format=List[DetectedRegion], # Pass the Pydantic model list
|
184
|
+
**completion_kwargs
|
185
|
+
)
|
186
|
+
|
187
|
+
logger.debug(f"Gemini response received via OpenAI lib.")
|
188
|
+
|
189
|
+
# --- 4. Process Parsed Response ---
|
190
|
+
if not completion.choices:
|
191
|
+
logger.error("Gemini response (via OpenAI lib) contained no choices.")
|
192
|
+
return []
|
193
|
+
|
194
|
+
# Get the parsed Pydantic objects
|
195
|
+
parsed_results = completion.choices[0].message.parsed
|
196
|
+
if not parsed_results or not isinstance(parsed_results, list):
|
197
|
+
logger.error(f"Gemini response (via OpenAI lib) did not contain a valid list of parsed regions. Found: {type(parsed_results)}")
|
198
|
+
return []
|
199
|
+
|
200
|
+
# --- 5. Convert to Detections & Filter ---
|
201
|
+
normalized_classes_req = {
|
202
|
+
self._normalize_class_name(c) for c in options.classes
|
203
|
+
}
|
204
|
+
normalized_classes_excl = {
|
205
|
+
self._normalize_class_name(c) for c in options.exclude_classes
|
206
|
+
} if options.exclude_classes else set()
|
207
|
+
|
208
|
+
for item in parsed_results:
|
209
|
+
# The item is already a validated DetectedRegion Pydantic object
|
210
|
+
# Access fields directly
|
211
|
+
label = item.label
|
212
|
+
bbox_raw = item.bbox
|
213
|
+
confidence_score = item.confidence
|
214
|
+
|
215
|
+
# Coordinates should already be floats, but ensure tuple format
|
216
|
+
xmin, ymin, xmax, ymax = tuple(bbox_raw)
|
217
|
+
|
218
|
+
# --- Apply Filtering ---
|
219
|
+
normalized_class = self._normalize_class_name(label)
|
220
|
+
|
221
|
+
# Check against requested classes (Should be guaranteed by schema, but doesn't hurt)
|
222
|
+
if normalized_class not in normalized_classes_req:
|
223
|
+
logger.warning(f"Gemini (via OpenAI) returned unexpected class '{label}' despite schema. Skipping.")
|
224
|
+
continue
|
225
|
+
|
226
|
+
# Check against excluded classes
|
227
|
+
if normalized_class in normalized_classes_excl:
|
228
|
+
logger.debug(f"Skipping excluded class '{label}' (normalized: {normalized_class}).")
|
229
|
+
continue
|
230
|
+
|
231
|
+
# Check against base confidence threshold from options
|
232
|
+
if confidence_score < options.confidence:
|
233
|
+
logger.debug(f"Skipping item with confidence {confidence_score:.3f} below threshold {options.confidence}.")
|
234
|
+
continue
|
235
|
+
|
236
|
+
# Add detection
|
237
|
+
detections.append({
|
238
|
+
"bbox": (xmin, ymin, xmax, ymax),
|
239
|
+
"class": label, # Use original label from LLM
|
240
|
+
"confidence": confidence_score,
|
241
|
+
"normalized_class": normalized_class,
|
242
|
+
"source": "layout",
|
243
|
+
"model": "gemini", # Keep model name generic as gemini
|
244
|
+
})
|
245
|
+
|
246
|
+
self.logger.info(
|
247
|
+
f"Gemini (via OpenAI lib) processed response. Detected {len(detections)} layout elements matching criteria."
|
248
|
+
)
|
249
|
+
|
250
|
+
except Exception as e:
|
251
|
+
# Catch potential OpenAI API errors or other issues
|
252
|
+
self.logger.error(f"Error during Gemini detection (via OpenAI lib): {e}", exc_info=True)
|
253
|
+
return []
|
254
|
+
|
255
|
+
return detections
|
256
|
+
|
257
|
+
def _normalize_class_name(self, name: str) -> str:
|
258
|
+
"""Normalizes class names for filtering (lowercase, hyphenated)."""
|
259
|
+
return super()._normalize_class_name(name)
|
260
|
+
|
261
|
+
def validate_classes(self, classes: List[str]):
|
262
|
+
"""Validation is less critical as we pass requested classes to the LLM."""
|
263
|
+
pass # Override base validation if needed, but likely not necessary
|
264
|
+
|