natural-pdf 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/finetuning/index.md +176 -0
- docs/ocr/index.md +34 -47
- docs/tutorials/01-loading-and-extraction.ipynb +34 -1536
- docs/tutorials/02-finding-elements.ipynb +42 -42
- docs/tutorials/03-extracting-blocks.ipynb +17 -17
- docs/tutorials/04-table-extraction.ipynb +12 -12
- docs/tutorials/05-excluding-content.ipynb +30 -30
- docs/tutorials/06-document-qa.ipynb +28 -28
- docs/tutorials/07-layout-analysis.ipynb +63 -35
- docs/tutorials/07-working-with-regions.ipynb +55 -51
- docs/tutorials/07-working-with-regions.md +2 -2
- docs/tutorials/08-spatial-navigation.ipynb +60 -60
- docs/tutorials/09-section-extraction.ipynb +113 -113
- docs/tutorials/10-form-field-extraction.ipynb +78 -50
- docs/tutorials/11-enhanced-table-processing.ipynb +6 -6
- docs/tutorials/12-ocr-integration.ipynb +149 -131
- docs/tutorials/12-ocr-integration.md +0 -13
- docs/tutorials/13-semantic-search.ipynb +313 -873
- natural_pdf/__init__.py +21 -22
- natural_pdf/analyzers/layout/gemini.py +280 -0
- natural_pdf/analyzers/layout/layout_manager.py +28 -1
- natural_pdf/analyzers/layout/layout_options.py +11 -0
- natural_pdf/analyzers/layout/yolo.py +6 -2
- natural_pdf/collections/pdf_collection.py +24 -0
- natural_pdf/core/element_manager.py +18 -13
- natural_pdf/core/page.py +174 -36
- natural_pdf/core/pdf.py +156 -42
- natural_pdf/elements/base.py +9 -17
- natural_pdf/elements/collections.py +99 -38
- natural_pdf/elements/region.py +77 -37
- natural_pdf/elements/text.py +5 -0
- natural_pdf/exporters/__init__.py +4 -0
- natural_pdf/exporters/base.py +61 -0
- natural_pdf/exporters/paddleocr.py +345 -0
- natural_pdf/ocr/__init__.py +57 -36
- natural_pdf/ocr/engine.py +160 -49
- natural_pdf/ocr/engine_easyocr.py +178 -157
- natural_pdf/ocr/engine_paddle.py +114 -189
- natural_pdf/ocr/engine_surya.py +87 -144
- natural_pdf/ocr/ocr_factory.py +125 -0
- natural_pdf/ocr/ocr_manager.py +65 -89
- natural_pdf/ocr/ocr_options.py +8 -13
- natural_pdf/ocr/utils.py +113 -0
- natural_pdf/templates/finetune/fine_tune_paddleocr.md +415 -0
- natural_pdf/templates/spa/css/style.css +334 -0
- natural_pdf/templates/spa/index.html +31 -0
- natural_pdf/templates/spa/js/app.js +472 -0
- natural_pdf/templates/spa/words.txt +235976 -0
- natural_pdf/utils/debug.py +34 -0
- natural_pdf/utils/identifiers.py +33 -0
- natural_pdf/utils/packaging.py +485 -0
- natural_pdf/utils/text_extraction.py +44 -64
- natural_pdf/utils/visualization.py +1 -1
- {natural_pdf-0.1.5.dist-info → natural_pdf-0.1.7.dist-info}/METADATA +44 -20
- {natural_pdf-0.1.5.dist-info → natural_pdf-0.1.7.dist-info}/RECORD +58 -47
- {natural_pdf-0.1.5.dist-info → natural_pdf-0.1.7.dist-info}/WHEEL +1 -1
- {natural_pdf-0.1.5.dist-info → natural_pdf-0.1.7.dist-info}/top_level.txt +0 -1
- natural_pdf/templates/ocr_debug.html +0 -517
- tests/test_loading.py +0 -50
- tests/test_optional_deps.py +0 -298
- {natural_pdf-0.1.5.dist-info → natural_pdf-0.1.7.dist-info}/licenses/LICENSE +0 -0
natural_pdf/__init__.py
CHANGED
@@ -12,17 +12,16 @@ logger = logging.getLogger("natural_pdf")
|
|
12
12
|
logger.addHandler(logging.NullHandler())
|
13
13
|
|
14
14
|
|
15
|
-
# Utility function for users to easily configure logging
|
16
15
|
def configure_logging(level=logging.INFO, handler=None):
|
17
|
-
"""Configure
|
16
|
+
"""Configure logging for the natural_pdf package.
|
18
17
|
|
19
18
|
Args:
|
20
|
-
level:
|
21
|
-
handler:
|
19
|
+
level: Logging level (e.g., logging.INFO, logging.DEBUG)
|
20
|
+
handler: Optional custom handler. Defaults to a StreamHandler.
|
22
21
|
"""
|
23
|
-
#
|
24
|
-
if
|
25
|
-
|
22
|
+
# Avoid adding duplicate handlers
|
23
|
+
if any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
|
24
|
+
return
|
26
25
|
|
27
26
|
if handler is None:
|
28
27
|
handler = logging.StreamHandler()
|
@@ -32,10 +31,7 @@ def configure_logging(level=logging.INFO, handler=None):
|
|
32
31
|
logger.addHandler(handler)
|
33
32
|
logger.setLevel(level)
|
34
33
|
|
35
|
-
|
36
|
-
for name in logging.root.manager.loggerDict:
|
37
|
-
if name.startswith("natural_pdf."):
|
38
|
-
logging.getLogger(name).setLevel(level)
|
34
|
+
logger.propagate = False
|
39
35
|
|
40
36
|
|
41
37
|
from natural_pdf.core.page import Page
|
@@ -53,18 +49,21 @@ except ImportError:
|
|
53
49
|
|
54
50
|
__version__ = "0.1.1"
|
55
51
|
|
52
|
+
__all__ = [
|
53
|
+
"PDF",
|
54
|
+
"PDFCollection",
|
55
|
+
"Page",
|
56
|
+
"Region",
|
57
|
+
"ElementCollection",
|
58
|
+
"TextSearchOptions",
|
59
|
+
"MultiModalSearchOptions",
|
60
|
+
"BaseSearchOptions",
|
61
|
+
"configure_logging",
|
62
|
+
]
|
63
|
+
|
56
64
|
if HAS_QA:
|
57
|
-
__all__
|
58
|
-
|
59
|
-
"Page",
|
60
|
-
"Region",
|
61
|
-
"ElementCollection",
|
62
|
-
"configure_logging",
|
63
|
-
"DocumentQA",
|
64
|
-
"get_qa_engine",
|
65
|
-
]
|
66
|
-
else:
|
67
|
-
__all__ = ["PDF", "Page", "Region", "ElementCollection", "configure_logging"]
|
65
|
+
__all__.extend(["DocumentQA", "get_qa_engine"])
|
66
|
+
|
68
67
|
|
69
68
|
from .collections.pdf_collection import PDFCollection
|
70
69
|
|
@@ -0,0 +1,280 @@
|
|
1
|
+
# layout_detector_gemini.py
|
2
|
+
import importlib.util
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
from typing import Any, Dict, List, Optional
|
6
|
+
import base64
|
7
|
+
import io
|
8
|
+
|
9
|
+
from pydantic import BaseModel, Field
|
10
|
+
from PIL import Image
|
11
|
+
|
12
|
+
# Use OpenAI library for interaction
|
13
|
+
try:
|
14
|
+
from openai import OpenAI
|
15
|
+
from openai.types.chat import ChatCompletion
|
16
|
+
|
17
|
+
# Import OpenAIError for exception handling if needed
|
18
|
+
except ImportError:
|
19
|
+
OpenAI = None
|
20
|
+
ChatCompletion = None
|
21
|
+
|
22
|
+
try:
|
23
|
+
from .base import LayoutDetector
|
24
|
+
from .layout_options import BaseLayoutOptions, GeminiLayoutOptions
|
25
|
+
except ImportError:
|
26
|
+
# Placeholders if run standalone or imports fail
|
27
|
+
class BaseLayoutOptions:
|
28
|
+
pass
|
29
|
+
|
30
|
+
class GeminiLayoutOptions(BaseLayoutOptions):
|
31
|
+
pass
|
32
|
+
|
33
|
+
class LayoutDetector:
|
34
|
+
def __init__(self):
|
35
|
+
self.logger = logging.getLogger()
|
36
|
+
self.supported_classes = set() # Will be dynamic based on user request
|
37
|
+
|
38
|
+
def _get_model(self, options):
|
39
|
+
raise NotImplementedError
|
40
|
+
|
41
|
+
def _normalize_class_name(self, n):
|
42
|
+
return n.lower().replace("_", "-").replace(" ", "-")
|
43
|
+
|
44
|
+
def validate_classes(self, c):
|
45
|
+
pass # Less strict validation needed for LLM
|
46
|
+
|
47
|
+
logging.basicConfig()
|
48
|
+
|
49
|
+
logger = logging.getLogger(__name__)
|
50
|
+
|
51
|
+
|
52
|
+
# Define Pydantic model for the expected output structure
|
53
|
+
# This is used by the openai library's `response_format`
|
54
|
+
class DetectedRegion(BaseModel):
|
55
|
+
label: str = Field(description="The identified class name.")
|
56
|
+
bbox: List[float] = Field(
|
57
|
+
description="Bounding box coordinates [xmin, ymin, xmax, ymax].", min_items=4, max_items=4
|
58
|
+
)
|
59
|
+
confidence: float = Field(description="Confidence score [0.0, 1.0].", ge=0.0, le=1.0)
|
60
|
+
|
61
|
+
|
62
|
+
class GeminiLayoutDetector(LayoutDetector):
|
63
|
+
"""Document layout detector using Google's Gemini models via OpenAI compatibility layer."""
|
64
|
+
|
65
|
+
# Base URL for the Gemini OpenAI-compatible endpoint
|
66
|
+
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
|
67
|
+
|
68
|
+
def __init__(self):
|
69
|
+
super().__init__()
|
70
|
+
self.supported_classes = set() # Indicate dynamic nature
|
71
|
+
|
72
|
+
def is_available(self) -> bool:
|
73
|
+
"""Check if openai library is installed and GOOGLE_API_KEY is available."""
|
74
|
+
api_key = os.environ.get("GOOGLE_API_KEY")
|
75
|
+
if not api_key:
|
76
|
+
logger.warning(
|
77
|
+
"GOOGLE_API_KEY environment variable not set. Gemini detector (via OpenAI lib) will not be available."
|
78
|
+
)
|
79
|
+
return False
|
80
|
+
if OpenAI is None:
|
81
|
+
logger.warning(
|
82
|
+
"openai package not found. Gemini detector (via OpenAI lib) will not be available."
|
83
|
+
)
|
84
|
+
return False
|
85
|
+
return True
|
86
|
+
|
87
|
+
def _get_cache_key(self, options: GeminiLayoutOptions) -> str:
|
88
|
+
"""Generate cache key based on model name."""
|
89
|
+
if not isinstance(options, GeminiLayoutOptions):
|
90
|
+
options = GeminiLayoutOptions() # Use defaults
|
91
|
+
|
92
|
+
model_key = options.model_name
|
93
|
+
# Prompt is built dynamically, so not part of cache key based on options
|
94
|
+
return f"{self.__class__.__name__}_{model_key}"
|
95
|
+
|
96
|
+
def _load_model_from_options(self, options: GeminiLayoutOptions) -> Any:
|
97
|
+
"""Validate options and return the model name."""
|
98
|
+
if not self.is_available():
|
99
|
+
raise RuntimeError(
|
100
|
+
"OpenAI library not installed or GOOGLE_API_KEY not set. Please run: pip install openai"
|
101
|
+
)
|
102
|
+
|
103
|
+
if not isinstance(options, GeminiLayoutOptions):
|
104
|
+
raise TypeError("Incorrect options type provided for Gemini model loading.")
|
105
|
+
|
106
|
+
# Simply return the model name, client is created in detect()
|
107
|
+
return options.model_name
|
108
|
+
|
109
|
+
def detect(self, image: Image.Image, options: BaseLayoutOptions) -> List[Dict[str, Any]]:
|
110
|
+
"""Detect layout elements in an image using Gemini via OpenAI library."""
|
111
|
+
if not self.is_available():
|
112
|
+
raise RuntimeError("OpenAI library not installed or GOOGLE_API_KEY not set.")
|
113
|
+
|
114
|
+
# Ensure options are the correct type
|
115
|
+
if not isinstance(options, GeminiLayoutOptions):
|
116
|
+
self.logger.warning(
|
117
|
+
"Received BaseLayoutOptions, expected GeminiLayoutOptions. Using defaults."
|
118
|
+
)
|
119
|
+
options = GeminiLayoutOptions(
|
120
|
+
confidence=options.confidence,
|
121
|
+
classes=options.classes,
|
122
|
+
exclude_classes=options.exclude_classes,
|
123
|
+
device=options.device,
|
124
|
+
extra_args=options.extra_args,
|
125
|
+
)
|
126
|
+
|
127
|
+
model_name = self._get_model(options)
|
128
|
+
api_key = os.environ.get("GOOGLE_API_KEY")
|
129
|
+
|
130
|
+
detections = []
|
131
|
+
try:
|
132
|
+
# --- 1. Initialize OpenAI Client for Gemini ---
|
133
|
+
client = OpenAI(api_key=api_key, base_url=self.GEMINI_BASE_URL)
|
134
|
+
|
135
|
+
# --- 2. Prepare Input for OpenAI API ---
|
136
|
+
if not options.classes:
|
137
|
+
logger.error("Gemini layout detection requires a list of classes to find.")
|
138
|
+
return []
|
139
|
+
|
140
|
+
width, height = image.size
|
141
|
+
|
142
|
+
# Convert image to base64
|
143
|
+
buffered = io.BytesIO()
|
144
|
+
image.save(buffered, format="PNG")
|
145
|
+
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
146
|
+
image_url = f"data:image/png;base64,{img_base64}"
|
147
|
+
|
148
|
+
# Construct the prompt text
|
149
|
+
class_list_str = ", ".join(f"`{c}`" for c in options.classes)
|
150
|
+
prompt_text = (
|
151
|
+
f"Analyze the provided image of a document page ({width}x{height}). "
|
152
|
+
f"Identify all regions corresponding to the following types: {class_list_str}. "
|
153
|
+
f"Return ONLY the structured data requested."
|
154
|
+
)
|
155
|
+
|
156
|
+
# Prepare messages for chat completions endpoint
|
157
|
+
messages = [
|
158
|
+
{
|
159
|
+
"role": "user",
|
160
|
+
"content": [
|
161
|
+
{"type": "text", "text": prompt_text},
|
162
|
+
{
|
163
|
+
"type": "image_url",
|
164
|
+
"image_url": {"url": image_url},
|
165
|
+
},
|
166
|
+
],
|
167
|
+
}
|
168
|
+
]
|
169
|
+
|
170
|
+
# --- 3. Call OpenAI API using .parse for structured output ---
|
171
|
+
logger.debug(
|
172
|
+
f"Running Gemini detection via OpenAI lib (Model: {model_name}). Asking for classes: {options.classes}"
|
173
|
+
)
|
174
|
+
|
175
|
+
# Extract relevant generation parameters from extra_args if provided
|
176
|
+
# Mapping common names: temperature, top_p, max_tokens
|
177
|
+
completion_kwargs = {
|
178
|
+
"temperature": options.extra_args.get("temperature", 0.2), # Default to low temp
|
179
|
+
"top_p": options.extra_args.get("top_p"),
|
180
|
+
"max_tokens": options.extra_args.get(
|
181
|
+
"max_tokens", 4096
|
182
|
+
), # Map from max_output_tokens
|
183
|
+
}
|
184
|
+
# Filter out None values
|
185
|
+
completion_kwargs = {k: v for k, v in completion_kwargs.items() if v is not None}
|
186
|
+
|
187
|
+
completion: ChatCompletion = client.beta.chat.completions.parse(
|
188
|
+
model=model_name,
|
189
|
+
messages=messages,
|
190
|
+
response_format=List[DetectedRegion], # Pass the Pydantic model list
|
191
|
+
**completion_kwargs,
|
192
|
+
)
|
193
|
+
|
194
|
+
logger.debug(f"Gemini response received via OpenAI lib.")
|
195
|
+
|
196
|
+
# --- 4. Process Parsed Response ---
|
197
|
+
if not completion.choices:
|
198
|
+
logger.error("Gemini response (via OpenAI lib) contained no choices.")
|
199
|
+
return []
|
200
|
+
|
201
|
+
# Get the parsed Pydantic objects
|
202
|
+
parsed_results = completion.choices[0].message.parsed
|
203
|
+
if not parsed_results or not isinstance(parsed_results, list):
|
204
|
+
logger.error(
|
205
|
+
f"Gemini response (via OpenAI lib) did not contain a valid list of parsed regions. Found: {type(parsed_results)}"
|
206
|
+
)
|
207
|
+
return []
|
208
|
+
|
209
|
+
# --- 5. Convert to Detections & Filter ---
|
210
|
+
normalized_classes_req = {self._normalize_class_name(c) for c in options.classes}
|
211
|
+
normalized_classes_excl = (
|
212
|
+
{self._normalize_class_name(c) for c in options.exclude_classes}
|
213
|
+
if options.exclude_classes
|
214
|
+
else set()
|
215
|
+
)
|
216
|
+
|
217
|
+
for item in parsed_results:
|
218
|
+
# The item is already a validated DetectedRegion Pydantic object
|
219
|
+
# Access fields directly
|
220
|
+
label = item.label
|
221
|
+
bbox_raw = item.bbox
|
222
|
+
confidence_score = item.confidence
|
223
|
+
|
224
|
+
# Coordinates should already be floats, but ensure tuple format
|
225
|
+
xmin, ymin, xmax, ymax = tuple(bbox_raw)
|
226
|
+
|
227
|
+
# --- Apply Filtering ---
|
228
|
+
normalized_class = self._normalize_class_name(label)
|
229
|
+
|
230
|
+
# Check against requested classes (Should be guaranteed by schema, but doesn't hurt)
|
231
|
+
if normalized_class not in normalized_classes_req:
|
232
|
+
logger.warning(
|
233
|
+
f"Gemini (via OpenAI) returned unexpected class '{label}' despite schema. Skipping."
|
234
|
+
)
|
235
|
+
continue
|
236
|
+
|
237
|
+
# Check against excluded classes
|
238
|
+
if normalized_class in normalized_classes_excl:
|
239
|
+
logger.debug(
|
240
|
+
f"Skipping excluded class '{label}' (normalized: {normalized_class})."
|
241
|
+
)
|
242
|
+
continue
|
243
|
+
|
244
|
+
# Check against base confidence threshold from options
|
245
|
+
if confidence_score < options.confidence:
|
246
|
+
logger.debug(
|
247
|
+
f"Skipping item with confidence {confidence_score:.3f} below threshold {options.confidence}."
|
248
|
+
)
|
249
|
+
continue
|
250
|
+
|
251
|
+
# Add detection
|
252
|
+
detections.append(
|
253
|
+
{
|
254
|
+
"bbox": (xmin, ymin, xmax, ymax),
|
255
|
+
"class": label, # Use original label from LLM
|
256
|
+
"confidence": confidence_score,
|
257
|
+
"normalized_class": normalized_class,
|
258
|
+
"source": "layout",
|
259
|
+
"model": "gemini", # Keep model name generic as gemini
|
260
|
+
}
|
261
|
+
)
|
262
|
+
|
263
|
+
self.logger.info(
|
264
|
+
f"Gemini (via OpenAI lib) processed response. Detected {len(detections)} layout elements matching criteria."
|
265
|
+
)
|
266
|
+
|
267
|
+
except Exception as e:
|
268
|
+
# Catch potential OpenAI API errors or other issues
|
269
|
+
self.logger.error(f"Error during Gemini detection (via OpenAI lib): {e}", exc_info=True)
|
270
|
+
return []
|
271
|
+
|
272
|
+
return detections
|
273
|
+
|
274
|
+
def _normalize_class_name(self, name: str) -> str:
|
275
|
+
"""Normalizes class names for filtering (lowercase, hyphenated)."""
|
276
|
+
return super()._normalize_class_name(name)
|
277
|
+
|
278
|
+
def validate_classes(self, classes: List[str]):
|
279
|
+
"""Validation is less critical as we pass requested classes to the LLM."""
|
280
|
+
pass # Override base validation if needed, but likely not necessary
|
@@ -37,9 +37,15 @@ try:
|
|
37
37
|
except ImportError:
|
38
38
|
DoclingLayoutDetector = None
|
39
39
|
|
40
|
+
try:
|
41
|
+
from .gemini import GeminiLayoutDetector
|
42
|
+
except ImportError:
|
43
|
+
GeminiLayoutDetector = None
|
44
|
+
|
40
45
|
from .layout_options import (
|
41
46
|
BaseLayoutOptions,
|
42
47
|
DoclingLayoutOptions,
|
48
|
+
GeminiLayoutOptions,
|
43
49
|
LayoutOptions,
|
44
50
|
PaddleLayoutOptions,
|
45
51
|
SuryaLayoutOptions,
|
@@ -83,6 +89,13 @@ class LayoutManager:
|
|
83
89
|
"options_class": DoclingLayoutOptions,
|
84
90
|
}
|
85
91
|
|
92
|
+
# Add Gemini entry if available
|
93
|
+
if GeminiLayoutDetector:
|
94
|
+
ENGINE_REGISTRY["gemini"] = {
|
95
|
+
"class": GeminiLayoutDetector,
|
96
|
+
"options_class": GeminiLayoutOptions,
|
97
|
+
}
|
98
|
+
|
86
99
|
# Define the limited set of kwargs allowed for the simple analyze_layout call
|
87
100
|
SIMPLE_MODE_ALLOWED_KWARGS = {"engine", "confidence", "classes", "exclude_classes", "device"}
|
88
101
|
|
@@ -108,8 +121,22 @@ class LayoutManager:
|
|
108
121
|
detector_instance = engine_class() # Instantiate
|
109
122
|
if not detector_instance.is_available():
|
110
123
|
# Check availability before storing
|
124
|
+
# Construct helpful error message with install hint
|
125
|
+
install_hint = ""
|
126
|
+
if engine_name == "yolo":
|
127
|
+
install_hint = "pip install 'natural-pdf[layout_yolo]'"
|
128
|
+
elif engine_name == "tatr":
|
129
|
+
install_hint = "pip install 'natural-pdf[core-ml]'"
|
130
|
+
elif engine_name == "paddle":
|
131
|
+
install_hint = "pip install 'natural-pdf[paddle]'"
|
132
|
+
elif engine_name == "surya":
|
133
|
+
install_hint = "pip install 'natural-pdf[surya]'"
|
134
|
+
# Add other engines like docling if they become optional extras
|
135
|
+
else:
|
136
|
+
install_hint = f"(Check installation requirements for {engine_name})"
|
137
|
+
|
111
138
|
raise RuntimeError(
|
112
|
-
f"Layout engine '{engine_name}' is not available. Please
|
139
|
+
f"Layout engine '{engine_name}' is not available. Please install the required dependencies: {install_hint}"
|
113
140
|
)
|
114
141
|
self._detector_instances[engine_name] = detector_instance # Store if available
|
115
142
|
|
@@ -80,6 +80,16 @@ class DoclingLayoutOptions(BaseLayoutOptions):
|
|
80
80
|
# Other kwargs like 'device', 'batch_size' can go in extra_args
|
81
81
|
|
82
82
|
|
83
|
+
# --- Gemini Specific Options ---
|
84
|
+
@dataclass
|
85
|
+
class GeminiLayoutOptions(BaseLayoutOptions):
|
86
|
+
"""Options specific to Gemini-based layout detection (using OpenAI compatibility)."""
|
87
|
+
|
88
|
+
model_name: str = "gemini-2.0-flash"
|
89
|
+
# Removed: prompt_template, temperature, top_p, max_output_tokens
|
90
|
+
# These are typically passed directly to the chat completion call or via extra_args
|
91
|
+
|
92
|
+
|
83
93
|
# --- Union Type ---
|
84
94
|
LayoutOptions = Union[
|
85
95
|
YOLOLayoutOptions,
|
@@ -87,5 +97,6 @@ LayoutOptions = Union[
|
|
87
97
|
PaddleLayoutOptions,
|
88
98
|
SuryaLayoutOptions,
|
89
99
|
DoclingLayoutOptions,
|
100
|
+
GeminiLayoutOptions,
|
90
101
|
BaseLayoutOptions, # Include base for typing flexibility
|
91
102
|
]
|
@@ -91,7 +91,9 @@ class YOLODocLayoutDetector(LayoutDetector):
|
|
91
91
|
def _load_model_from_options(self, options: YOLOLayoutOptions) -> Any:
|
92
92
|
"""Load the YOLOv10 model based on options."""
|
93
93
|
if not self.is_available():
|
94
|
-
raise RuntimeError(
|
94
|
+
raise RuntimeError(
|
95
|
+
"YOLO dependencies not installed. Please run: pip install 'natural-pdf[layout_yolo]'"
|
96
|
+
)
|
95
97
|
self.logger.info(f"Loading YOLO model: {options.model_repo}/{options.model_file}")
|
96
98
|
try:
|
97
99
|
model_path = hf_hub_download(repo_id=options.model_repo, filename=options.model_file)
|
@@ -105,7 +107,9 @@ class YOLODocLayoutDetector(LayoutDetector):
|
|
105
107
|
def detect(self, image: Image.Image, options: BaseLayoutOptions) -> List[Dict[str, Any]]:
|
106
108
|
"""Detect layout elements in an image using YOLO."""
|
107
109
|
if not self.is_available():
|
108
|
-
raise RuntimeError(
|
110
|
+
raise RuntimeError(
|
111
|
+
"YOLO dependencies not installed. Please run: pip install 'natural-pdf[layout_yolo]'"
|
112
|
+
)
|
109
113
|
|
110
114
|
# Ensure options are the correct type, falling back to defaults if base type passed
|
111
115
|
if not isinstance(options, YOLOLayoutOptions):
|
@@ -267,6 +267,30 @@ class PDFCollection(SearchableMixin): # Inherit from the mixin
|
|
267
267
|
# Implementation requires integrating with classification models or logic
|
268
268
|
raise NotImplementedError("categorize requires classification implementation.")
|
269
269
|
|
270
|
+
def export_ocr_correction_task(self, output_zip_path: str, **kwargs):
|
271
|
+
"""
|
272
|
+
Exports OCR results from all PDFs in this collection into a single
|
273
|
+
correction task package (zip file).
|
274
|
+
|
275
|
+
Args:
|
276
|
+
output_zip_path: The path to save the output zip file.
|
277
|
+
**kwargs: Additional arguments passed to create_correction_task_package
|
278
|
+
(e.g., image_render_scale, overwrite).
|
279
|
+
"""
|
280
|
+
try:
|
281
|
+
from natural_pdf.utils.packaging import create_correction_task_package
|
282
|
+
|
283
|
+
# Pass the collection itself (self) as the source
|
284
|
+
create_correction_task_package(source=self, output_zip_path=output_zip_path, **kwargs)
|
285
|
+
except ImportError:
|
286
|
+
logger.error(
|
287
|
+
"Failed to import 'create_correction_task_package'. Packaging utility might be missing."
|
288
|
+
)
|
289
|
+
# Or raise
|
290
|
+
except Exception as e:
|
291
|
+
logger.error(f"Failed to export correction task for collection: {e}", exc_info=True)
|
292
|
+
raise # Re-raise the exception from the utility function
|
293
|
+
|
270
294
|
# --- Mixin Required Implementation ---
|
271
295
|
def get_indexable_items(self) -> Iterable[Indexable]:
|
272
296
|
"""Yields Page objects from the collection, conforming to Indexable."""
|
@@ -312,6 +312,7 @@ class ElementManager:
|
|
312
312
|
|
313
313
|
Args:
|
314
314
|
ocr_results: List of OCR results dictionaries with 'text', 'bbox', 'confidence'.
|
315
|
+
Confidence can be None for detection-only results.
|
315
316
|
scale_x: Factor to convert image x-coordinates to PDF coordinates.
|
316
317
|
scale_y: Factor to convert image y-coordinates to PDF coordinates.
|
317
318
|
|
@@ -356,9 +357,16 @@ class ElementManager:
|
|
356
357
|
pdf_bottom = bottom_img * scale_y
|
357
358
|
pdf_height = (bottom_img - top_img) * scale_y
|
358
359
|
|
360
|
+
# Handle potential None confidence
|
361
|
+
raw_confidence = result.get("confidence")
|
362
|
+
confidence_value = (
|
363
|
+
float(raw_confidence) if raw_confidence is not None else None
|
364
|
+
) # Keep None if it was None
|
365
|
+
ocr_text = result.get("text") # Get text, will be None if detect_only
|
366
|
+
|
359
367
|
# Create the TextElement for the word
|
360
368
|
word_element_data = {
|
361
|
-
"text":
|
369
|
+
"text": ocr_text,
|
362
370
|
"x0": pdf_x0,
|
363
371
|
"top": pdf_top,
|
364
372
|
"x1": pdf_x1,
|
@@ -367,7 +375,7 @@ class ElementManager:
|
|
367
375
|
"height": pdf_height,
|
368
376
|
"object_type": "word", # Treat OCR results as whole words
|
369
377
|
"source": "ocr",
|
370
|
-
"confidence":
|
378
|
+
"confidence": confidence_value, # Use the handled confidence
|
371
379
|
"fontname": "OCR", # Use consistent OCR fontname
|
372
380
|
"size": (
|
373
381
|
round(pdf_height) if pdf_height > 0 else 10.0
|
@@ -385,7 +393,7 @@ class ElementManager:
|
|
385
393
|
ocr_char_dict.setdefault("adv", ocr_char_dict.get("width", 0))
|
386
394
|
|
387
395
|
# Add the char dict list to the word data before creating TextElement
|
388
|
-
word_element_data["_char_dicts"] = [ocr_char_dict]
|
396
|
+
word_element_data["_char_dicts"] = [ocr_char_dict] # Store itself as its only char
|
389
397
|
|
390
398
|
word_elem = TextElement(word_element_data, self._page)
|
391
399
|
added_word_elements.append(word_elem)
|
@@ -393,16 +401,13 @@ class ElementManager:
|
|
393
401
|
# Append the word element to the manager's list
|
394
402
|
self._elements["words"].append(word_elem)
|
395
403
|
|
396
|
-
#
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
char_dict_data.setdefault("adv", char_dict_data.get("width", 0))
|
404
|
-
|
405
|
-
self._elements["chars"].append(char_dict_data) # Append the dictionary
|
404
|
+
# Only add a representative char dict if text actually exists
|
405
|
+
if ocr_text is not None:
|
406
|
+
# This char dict represents the entire OCR word as a single 'char'.
|
407
|
+
char_dict_data = ocr_char_dict # Use the one we already created
|
408
|
+
char_dict_data["object_type"] = "char" # Mark as char type
|
409
|
+
char_dict_data.setdefault("adv", char_dict_data.get("width", 0))
|
410
|
+
self._elements["chars"].append(char_dict_data) # Append the dictionary
|
406
411
|
|
407
412
|
except (KeyError, ValueError, TypeError) as e:
|
408
413
|
logger.error(f"Failed to process OCR result: {result}. Error: {e}", exc_info=True)
|