natural-pdf 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/categorizing-documents/index.md +168 -0
- docs/data-extraction/index.md +87 -0
- docs/element-selection/index.ipynb +218 -164
- docs/element-selection/index.md +20 -0
- docs/finetuning/index.md +176 -0
- docs/index.md +19 -0
- docs/ocr/index.md +63 -16
- docs/tutorials/01-loading-and-extraction.ipynb +411 -248
- docs/tutorials/02-finding-elements.ipynb +123 -46
- docs/tutorials/03-extracting-blocks.ipynb +24 -19
- docs/tutorials/04-table-extraction.ipynb +17 -12
- docs/tutorials/05-excluding-content.ipynb +37 -32
- docs/tutorials/06-document-qa.ipynb +36 -31
- docs/tutorials/07-layout-analysis.ipynb +45 -40
- docs/tutorials/07-working-with-regions.ipynb +61 -60
- docs/tutorials/08-spatial-navigation.ipynb +76 -71
- docs/tutorials/09-section-extraction.ipynb +160 -155
- docs/tutorials/10-form-field-extraction.ipynb +71 -66
- docs/tutorials/11-enhanced-table-processing.ipynb +11 -6
- docs/tutorials/12-ocr-integration.ipynb +3420 -312
- docs/tutorials/12-ocr-integration.md +68 -106
- docs/tutorials/13-semantic-search.ipynb +641 -251
- natural_pdf/__init__.py +3 -0
- natural_pdf/analyzers/layout/gemini.py +63 -47
- natural_pdf/classification/manager.py +343 -0
- natural_pdf/classification/mixin.py +149 -0
- natural_pdf/classification/results.py +62 -0
- natural_pdf/collections/mixins.py +63 -0
- natural_pdf/collections/pdf_collection.py +326 -17
- natural_pdf/core/element_manager.py +73 -4
- natural_pdf/core/page.py +255 -83
- natural_pdf/core/pdf.py +385 -367
- natural_pdf/elements/base.py +1 -3
- natural_pdf/elements/collections.py +279 -49
- natural_pdf/elements/region.py +106 -21
- natural_pdf/elements/text.py +5 -2
- natural_pdf/exporters/__init__.py +4 -0
- natural_pdf/exporters/base.py +61 -0
- natural_pdf/exporters/paddleocr.py +345 -0
- natural_pdf/extraction/manager.py +134 -0
- natural_pdf/extraction/mixin.py +246 -0
- natural_pdf/extraction/result.py +37 -0
- natural_pdf/ocr/__init__.py +16 -8
- natural_pdf/ocr/engine.py +46 -30
- natural_pdf/ocr/engine_easyocr.py +86 -42
- natural_pdf/ocr/engine_paddle.py +39 -28
- natural_pdf/ocr/engine_surya.py +32 -16
- natural_pdf/ocr/ocr_factory.py +34 -23
- natural_pdf/ocr/ocr_manager.py +98 -34
- natural_pdf/ocr/ocr_options.py +38 -10
- natural_pdf/ocr/utils.py +59 -33
- natural_pdf/qa/document_qa.py +0 -4
- natural_pdf/selectors/parser.py +363 -238
- natural_pdf/templates/finetune/fine_tune_paddleocr.md +420 -0
- natural_pdf/utils/debug.py +4 -2
- natural_pdf/utils/identifiers.py +9 -5
- natural_pdf/utils/locks.py +8 -0
- natural_pdf/utils/packaging.py +172 -105
- natural_pdf/utils/text_extraction.py +96 -65
- natural_pdf/utils/tqdm_utils.py +43 -0
- natural_pdf/utils/visualization.py +1 -1
- {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/METADATA +10 -3
- {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/RECORD +66 -51
- {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/WHEEL +1 -1
- {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/licenses/LICENSE +0 -0
- {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/top_level.txt +0 -0
natural_pdf/__init__.py
CHANGED
@@ -3,6 +3,8 @@ Natural PDF - A more intuitive interface for working with PDFs.
|
|
3
3
|
"""
|
4
4
|
|
5
5
|
import logging
|
6
|
+
import os
|
7
|
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
6
8
|
|
7
9
|
# Create library logger
|
8
10
|
logger = logging.getLogger("natural_pdf")
|
@@ -33,6 +35,7 @@ def configure_logging(level=logging.INFO, handler=None):
|
|
33
35
|
|
34
36
|
logger.propagate = False
|
35
37
|
|
38
|
+
|
36
39
|
from natural_pdf.core.page import Page
|
37
40
|
from natural_pdf.core.pdf import PDF
|
38
41
|
from natural_pdf.elements.collections import ElementCollection
|
@@ -13,6 +13,7 @@ from PIL import Image
|
|
13
13
|
try:
|
14
14
|
from openai import OpenAI
|
15
15
|
from openai.types.chat import ChatCompletion
|
16
|
+
|
16
17
|
# Import OpenAIError for exception handling if needed
|
17
18
|
except ImportError:
|
18
19
|
OpenAI = None
|
@@ -32,7 +33,7 @@ except ImportError:
|
|
32
33
|
class LayoutDetector:
|
33
34
|
def __init__(self):
|
34
35
|
self.logger = logging.getLogger()
|
35
|
-
self.supported_classes = set()
|
36
|
+
self.supported_classes = set() # Will be dynamic based on user request
|
36
37
|
|
37
38
|
def _get_model(self, options):
|
38
39
|
raise NotImplementedError
|
@@ -41,17 +42,20 @@ except ImportError:
|
|
41
42
|
return n.lower().replace("_", "-").replace(" ", "-")
|
42
43
|
|
43
44
|
def validate_classes(self, c):
|
44
|
-
pass
|
45
|
+
pass # Less strict validation needed for LLM
|
45
46
|
|
46
47
|
logging.basicConfig()
|
47
48
|
|
48
49
|
logger = logging.getLogger(__name__)
|
49
50
|
|
51
|
+
|
50
52
|
# Define Pydantic model for the expected output structure
|
51
53
|
# This is used by the openai library's `response_format`
|
52
54
|
class DetectedRegion(BaseModel):
|
53
55
|
label: str = Field(description="The identified class name.")
|
54
|
-
bbox: List[float] = Field(
|
56
|
+
bbox: List[float] = Field(
|
57
|
+
description="Bounding box coordinates [xmin, ymin, xmax, ymax].", min_items=4, max_items=4
|
58
|
+
)
|
55
59
|
confidence: float = Field(description="Confidence score [0.0, 1.0].", ge=0.0, le=1.0)
|
56
60
|
|
57
61
|
|
@@ -63,23 +67,27 @@ class GeminiLayoutDetector(LayoutDetector):
|
|
63
67
|
|
64
68
|
def __init__(self):
|
65
69
|
super().__init__()
|
66
|
-
self.supported_classes = set()
|
70
|
+
self.supported_classes = set() # Indicate dynamic nature
|
67
71
|
|
68
72
|
def is_available(self) -> bool:
|
69
73
|
"""Check if openai library is installed and GOOGLE_API_KEY is available."""
|
70
74
|
api_key = os.environ.get("GOOGLE_API_KEY")
|
71
75
|
if not api_key:
|
72
|
-
logger.warning(
|
76
|
+
logger.warning(
|
77
|
+
"GOOGLE_API_KEY environment variable not set. Gemini detector (via OpenAI lib) will not be available."
|
78
|
+
)
|
73
79
|
return False
|
74
80
|
if OpenAI is None:
|
75
|
-
|
76
|
-
|
81
|
+
logger.warning(
|
82
|
+
"openai package not found. Gemini detector (via OpenAI lib) will not be available."
|
83
|
+
)
|
84
|
+
return False
|
77
85
|
return True
|
78
86
|
|
79
87
|
def _get_cache_key(self, options: GeminiLayoutOptions) -> str:
|
80
88
|
"""Generate cache key based on model name."""
|
81
89
|
if not isinstance(options, GeminiLayoutOptions):
|
82
|
-
options = GeminiLayoutOptions()
|
90
|
+
options = GeminiLayoutOptions() # Use defaults
|
83
91
|
|
84
92
|
model_key = options.model_name
|
85
93
|
# Prompt is built dynamically, so not part of cache key based on options
|
@@ -101,9 +109,7 @@ class GeminiLayoutDetector(LayoutDetector):
|
|
101
109
|
def detect(self, image: Image.Image, options: BaseLayoutOptions) -> List[Dict[str, Any]]:
|
102
110
|
"""Detect layout elements in an image using Gemini via OpenAI library."""
|
103
111
|
if not self.is_available():
|
104
|
-
raise RuntimeError(
|
105
|
-
"OpenAI library not installed or GOOGLE_API_KEY not set."
|
106
|
-
)
|
112
|
+
raise RuntimeError("OpenAI library not installed or GOOGLE_API_KEY not set.")
|
107
113
|
|
108
114
|
# Ensure options are the correct type
|
109
115
|
if not isinstance(options, GeminiLayoutOptions):
|
@@ -124,10 +130,7 @@ class GeminiLayoutDetector(LayoutDetector):
|
|
124
130
|
detections = []
|
125
131
|
try:
|
126
132
|
# --- 1. Initialize OpenAI Client for Gemini ---
|
127
|
-
client = OpenAI(
|
128
|
-
api_key=api_key,
|
129
|
-
base_url=self.GEMINI_BASE_URL
|
130
|
-
)
|
133
|
+
client = OpenAI(api_key=api_key, base_url=self.GEMINI_BASE_URL)
|
131
134
|
|
132
135
|
# --- 2. Prepare Input for OpenAI API ---
|
133
136
|
if not options.classes:
|
@@ -139,11 +142,11 @@ class GeminiLayoutDetector(LayoutDetector):
|
|
139
142
|
# Convert image to base64
|
140
143
|
buffered = io.BytesIO()
|
141
144
|
image.save(buffered, format="PNG")
|
142
|
-
img_base64 = base64.b64encode(buffered.getvalue()).decode(
|
145
|
+
img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
|
143
146
|
image_url = f"data:image/png;base64,{img_base64}"
|
144
147
|
|
145
148
|
# Construct the prompt text
|
146
|
-
class_list_str = ", ".join(f
|
149
|
+
class_list_str = ", ".join(f"`{c}`" for c in options.classes)
|
147
150
|
prompt_text = (
|
148
151
|
f"Analyze the provided image of a document page ({width}x{height}). "
|
149
152
|
f"Identify all regions corresponding to the following types: {class_list_str}. "
|
@@ -165,14 +168,18 @@ class GeminiLayoutDetector(LayoutDetector):
|
|
165
168
|
]
|
166
169
|
|
167
170
|
# --- 3. Call OpenAI API using .parse for structured output ---
|
168
|
-
logger.debug(
|
171
|
+
logger.debug(
|
172
|
+
f"Running Gemini detection via OpenAI lib (Model: {model_name}). Asking for classes: {options.classes}"
|
173
|
+
)
|
169
174
|
|
170
175
|
# Extract relevant generation parameters from extra_args if provided
|
171
176
|
# Mapping common names: temperature, top_p, max_tokens
|
172
177
|
completion_kwargs = {
|
173
|
-
"temperature": options.extra_args.get("temperature", 0.2),
|
178
|
+
"temperature": options.extra_args.get("temperature", 0.2), # Default to low temp
|
174
179
|
"top_p": options.extra_args.get("top_p"),
|
175
|
-
"max_tokens": options.extra_args.get(
|
180
|
+
"max_tokens": options.extra_args.get(
|
181
|
+
"max_tokens", 4096
|
182
|
+
), # Map from max_output_tokens
|
176
183
|
}
|
177
184
|
# Filter out None values
|
178
185
|
completion_kwargs = {k: v for k, v in completion_kwargs.items() if v is not None}
|
@@ -180,13 +187,13 @@ class GeminiLayoutDetector(LayoutDetector):
|
|
180
187
|
completion: ChatCompletion = client.beta.chat.completions.parse(
|
181
188
|
model=model_name,
|
182
189
|
messages=messages,
|
183
|
-
response_format=List[DetectedRegion],
|
184
|
-
**completion_kwargs
|
190
|
+
response_format=List[DetectedRegion], # Pass the Pydantic model list
|
191
|
+
**completion_kwargs,
|
185
192
|
)
|
186
193
|
|
187
194
|
logger.debug(f"Gemini response received via OpenAI lib.")
|
188
195
|
|
189
|
-
# --- 4. Process Parsed Response ---
|
196
|
+
# --- 4. Process Parsed Response ---
|
190
197
|
if not completion.choices:
|
191
198
|
logger.error("Gemini response (via OpenAI lib) contained no choices.")
|
192
199
|
return []
|
@@ -194,16 +201,18 @@ class GeminiLayoutDetector(LayoutDetector):
|
|
194
201
|
# Get the parsed Pydantic objects
|
195
202
|
parsed_results = completion.choices[0].message.parsed
|
196
203
|
if not parsed_results or not isinstance(parsed_results, list):
|
197
|
-
|
198
|
-
|
204
|
+
logger.error(
|
205
|
+
f"Gemini response (via OpenAI lib) did not contain a valid list of parsed regions. Found: {type(parsed_results)}"
|
206
|
+
)
|
207
|
+
return []
|
199
208
|
|
200
|
-
# --- 5. Convert to Detections & Filter ---
|
201
|
-
normalized_classes_req = {
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
209
|
+
# --- 5. Convert to Detections & Filter ---
|
210
|
+
normalized_classes_req = {self._normalize_class_name(c) for c in options.classes}
|
211
|
+
normalized_classes_excl = (
|
212
|
+
{self._normalize_class_name(c) for c in options.exclude_classes}
|
213
|
+
if options.exclude_classes
|
214
|
+
else set()
|
215
|
+
)
|
207
216
|
|
208
217
|
for item in parsed_results:
|
209
218
|
# The item is already a validated DetectedRegion Pydantic object
|
@@ -215,33 +224,41 @@ class GeminiLayoutDetector(LayoutDetector):
|
|
215
224
|
# Coordinates should already be floats, but ensure tuple format
|
216
225
|
xmin, ymin, xmax, ymax = tuple(bbox_raw)
|
217
226
|
|
218
|
-
# --- Apply Filtering ---
|
227
|
+
# --- Apply Filtering ---
|
219
228
|
normalized_class = self._normalize_class_name(label)
|
220
229
|
|
221
230
|
# Check against requested classes (Should be guaranteed by schema, but doesn't hurt)
|
222
231
|
if normalized_class not in normalized_classes_req:
|
223
|
-
logger.warning(
|
232
|
+
logger.warning(
|
233
|
+
f"Gemini (via OpenAI) returned unexpected class '{label}' despite schema. Skipping."
|
234
|
+
)
|
224
235
|
continue
|
225
236
|
|
226
237
|
# Check against excluded classes
|
227
238
|
if normalized_class in normalized_classes_excl:
|
228
|
-
logger.debug(
|
239
|
+
logger.debug(
|
240
|
+
f"Skipping excluded class '{label}' (normalized: {normalized_class})."
|
241
|
+
)
|
229
242
|
continue
|
230
|
-
|
243
|
+
|
231
244
|
# Check against base confidence threshold from options
|
232
245
|
if confidence_score < options.confidence:
|
233
|
-
logger.debug(
|
246
|
+
logger.debug(
|
247
|
+
f"Skipping item with confidence {confidence_score:.3f} below threshold {options.confidence}."
|
248
|
+
)
|
234
249
|
continue
|
235
250
|
|
236
251
|
# Add detection
|
237
|
-
detections.append(
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
252
|
+
detections.append(
|
253
|
+
{
|
254
|
+
"bbox": (xmin, ymin, xmax, ymax),
|
255
|
+
"class": label, # Use original label from LLM
|
256
|
+
"confidence": confidence_score,
|
257
|
+
"normalized_class": normalized_class,
|
258
|
+
"source": "layout",
|
259
|
+
"model": "gemini", # Keep model name generic as gemini
|
260
|
+
}
|
261
|
+
)
|
245
262
|
|
246
263
|
self.logger.info(
|
247
264
|
f"Gemini (via OpenAI lib) processed response. Detected {len(detections)} layout elements matching criteria."
|
@@ -260,5 +277,4 @@ class GeminiLayoutDetector(LayoutDetector):
|
|
260
277
|
|
261
278
|
def validate_classes(self, classes: List[str]):
|
262
279
|
"""Validation is less critical as we pass requested classes to the LLM."""
|
263
|
-
pass
|
264
|
-
|
280
|
+
pass # Override base validation if needed, but likely not necessary
|
@@ -0,0 +1,343 @@
|
|
1
|
+
import logging
|
2
|
+
import time
|
3
|
+
from datetime import datetime
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Tuple
|
5
|
+
|
6
|
+
# Use try-except for robustness if dependencies are missing
|
7
|
+
try:
|
8
|
+
import torch
|
9
|
+
from PIL import Image
|
10
|
+
from transformers import pipeline, AutoTokenizer, AutoModelForZeroShotImageClassification, AutoModelForSequenceClassification
|
11
|
+
_CLASSIFICATION_AVAILABLE = True
|
12
|
+
except ImportError:
|
13
|
+
_CLASSIFICATION_AVAILABLE = False
|
14
|
+
# Define dummy types for type hinting if imports fail
|
15
|
+
Image = type("Image", (), {})
|
16
|
+
pipeline = object
|
17
|
+
AutoTokenizer = object
|
18
|
+
AutoModelForZeroShotImageClassification = object
|
19
|
+
AutoModelForSequenceClassification = object
|
20
|
+
torch = None
|
21
|
+
|
22
|
+
# Import result classes
|
23
|
+
from .results import ClassificationResult, CategoryScore
|
24
|
+
from natural_pdf.utils.tqdm_utils import get_tqdm
|
25
|
+
|
26
|
+
if TYPE_CHECKING:
|
27
|
+
from transformers import Pipeline
|
28
|
+
|
29
|
+
|
30
|
+
logger = logging.getLogger(__name__)

# Module-wide caches shared by every ClassificationManager in this process.
# Pipeline keys are built as "<model_id>_<using>_<device>".
_PIPELINE_CACHE: Dict[str, "Pipeline"] = dict()
# NOTE(review): the two caches below are not referenced elsewhere in this module.
_TOKENIZER_CACHE: Dict[str, Any] = dict()
_MODEL_CACHE: Dict[str, Any] = dict()
|
36
|
+
|
37
|
+
class ClassificationError(Exception):
    """Raised when a classification pipeline cannot be loaded or fails while running."""
|
40
|
+
|
41
|
+
|
42
|
+
class ClassificationManager:
    """Run zero-shot text or image classification through Hugging Face pipelines.

    Pipelines are created on demand with ``transformers.pipeline`` and stored in the
    module-level ``_PIPELINE_CACHE`` keyed by model id, mode and device, so every
    manager instance in the process reuses already-loaded models.
    """

    DEFAULT_TEXT_MODEL = "facebook/bart-large-mnli"
    DEFAULT_VISION_MODEL = "openai/clip-vit-base-patch16"

    def __init__(
        self,
        model_mapping: Optional[Dict[str, str]] = None,
        default_device: Optional[str] = None,
    ):
        """
        Initialize the ClassificationManager.

        Args:
            model_mapping: Optional dictionary mapping aliases ('text', 'vision') to
                model IDs. NOTE(review): accepted for API compatibility but not used.
            default_device: Device forwarded to ``transformers.pipeline`` for every
                pipeline this manager loads (e.g. 'cpu', 'cuda'); ``None`` is passed
                through unchanged.

        Raises:
            ImportError: If the optional classification dependencies failed to import.
        """
        if not _CLASSIFICATION_AVAILABLE:
            raise ImportError(
                "Classification dependencies missing. "
                "Install with: pip install \"natural-pdf[classification]\""
            )

        # Kept for backward compatibility only; loaded pipelines live in the
        # module-level _PIPELINE_CACHE, not here.
        self.pipelines: Dict[Tuple[str, str], "Pipeline"] = {}

        self.device = default_device
        logger.info(f"ClassificationManager initialized on device: {self.device}")

    def is_available(self) -> bool:
        """Check if required dependencies are installed."""
        return _CLASSIFICATION_AVAILABLE

    def _get_pipeline(self, model_id: str, using: str) -> "Pipeline":
        """Return the cached zero-shot pipeline for ``model_id``, loading it on first use.

        ``using == "text"`` selects the ``zero-shot-classification`` task; any other
        value selects ``zero-shot-image-classification``.

        Raises:
            ClassificationError: If ``transformers.pipeline`` fails to build the pipeline.
        """
        cache_key = f"{model_id}_{using}_{self.device}"
        cached = _PIPELINE_CACHE.get(cache_key)
        if cached is not None:
            return cached

        task = "zero-shot-classification" if using == "text" else "zero-shot-image-classification"
        logger.info(f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'...")
        start_time = time.time()
        try:
            loaded = pipeline(task, model=model_id, device=self.device)
        except Exception as e:
            logger.error(f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}", exc_info=True)
            raise ClassificationError(
                f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task."
            ) from e
        logger.info(f"Pipeline for '{model_id}' loaded in {time.time() - start_time:.2f} seconds.")
        _PIPELINE_CACHE[cache_key] = loaded
        return loaded

    def infer_using(self, model_id: str, using: Optional[str] = None) -> str:
        """Return the processing mode ('text' or 'vision') for ``model_id``.

        An explicit ``using`` of 'text' or 'vision' is returned as-is. Otherwise the
        mode is guessed from substrings of the lower-cased model id and, as a last
        resort, by attempting to load a text pipeline and then a vision pipeline
        (a successful load is left in the pipeline cache).

        Raises:
            ClassificationError: If the model loads as neither a text nor a vision pipeline.
        """
        if using in ("text", "vision"):
            return using

        normalized_model_id = model_id.lower()
        if any(tag in normalized_model_id for tag in ("clip", "vit", "siglip")):
            logger.debug(f"Inferred using='vision' for model '{model_id}'")
            return "vision"
        if any(tag in normalized_model_id for tag in ("bart", "bert", "mnli", "xnli", "deberta")):
            logger.debug(f"Inferred using='text' for model '{model_id}'")
            return "text"

        logger.warning(f"Could not reliably infer mode for '{model_id}'. Trying text, then vision pipeline loading.")
        try:
            self._get_pipeline(model_id, "text")
            logger.info(f"Successfully loaded '{model_id}' as a text model.")
            return "text"
        except Exception:
            logger.warning(f"Failed to load '{model_id}' as text model. Trying vision.")

        try:
            self._get_pipeline(model_id, "vision")
            logger.info(f"Successfully loaded '{model_id}' as a vision model.")
            return "vision"
        except Exception as e_vision:
            logger.error(f"Failed to load '{model_id}' as either text or vision model.", exc_info=True)
            raise ClassificationError(
                f"Cannot determine mode for model '{model_id}'. Please specify `using='text'` or `using='vision'`. Error: {e_vision}"
            ) from e_vision

    def _resolve_model(
        self,
        sample: Any,
        model_id: Optional[str],
        using: Optional[str],
        type_error_prefix: str,
    ) -> Tuple[str, str]:
        """Return ``(model_id, mode)`` for a classification call.

        When ``model_id`` is given, the mode comes from :meth:`infer_using`. When it is
        ``None``, the type of ``sample`` picks the default model and mode; ``using`` is
        not consulted in that case.

        Raises:
            TypeError: If ``model_id`` is None and ``sample`` is neither str nor a PIL image.
        """
        if model_id is not None:
            return model_id, self.infer_using(model_id, using)
        if isinstance(sample, str):
            return self.DEFAULT_TEXT_MODEL, "text"
        if isinstance(sample, Image.Image):
            return self.DEFAULT_VISION_MODEL, "vision"
        raise TypeError(f"{type_error_prefix}: {type(sample)}")

    @staticmethod
    def _scores_from_raw(
        raw_result: Any,
        min_confidence: float,
        model_id: str,
        in_batch: bool = False,
    ) -> List["CategoryScore"]:
        """Convert one raw zero-shot pipeline output into CategoryScore objects.

        Two shapes are understood: the text-pipeline dict with parallel 'labels' and
        'scores' lists, and the vision-pipeline list of {'label', 'score'} dicts.
        Entries scoring below ``min_confidence`` are dropped; malformed entries are
        logged and skipped one at a time. Any other shape is logged and yields [].
        """
        scores: List["CategoryScore"] = []

        if isinstance(raw_result, dict) and "labels" in raw_result and "scores" in raw_result:
            where = "text pipeline batch" if in_batch else "text pipeline"
            for label, score_val in zip(raw_result["labels"], raw_result["scores"]):
                try:
                    if score_val >= min_confidence:
                        scores.append(CategoryScore(label=label, confidence=score_val))
                except (ValueError, TypeError) as score_err:
                    logger.warning(
                        f"Skipping invalid score from {where}: label='{label}', score={score_val}. Error: {score_err}"
                    )
        elif isinstance(raw_result, list):
            where = "vision result list from batch" if in_batch else "vision result list"
            for entry in raw_result:
                try:
                    score_val = entry["score"]
                    label = entry["label"]
                    if score_val >= min_confidence:
                        scores.append(CategoryScore(label=label, confidence=score_val))
                except (KeyError, ValueError, TypeError) as item_err:
                    logger.warning(f"Skipping invalid item in {where}: {entry}. Error: {item_err}")
        else:
            location = "in batch item from" if in_batch else "from pipeline for"
            logger.warning(
                f"Unexpected raw result format {location} model '{model_id}': {type(raw_result)}. Cannot extract scores."
            )
        return scores

    def classify_item(
        self,
        item_content: Union[str, "Image.Image"],
        categories: List[str],
        model_id: Optional[str] = None,
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        multi_label: bool = False,
        **kwargs,
    ) -> "ClassificationResult":
        """Classify a single text string or PIL image against ``categories``.

        Args:
            item_content: Text or PIL image to classify.
            categories: Candidate labels passed to the pipeline as ``candidate_labels``.
            model_id: Model to use; when None, a default text or vision model is
                chosen from the type of ``item_content``.
            using: 'text' or 'vision'; only consulted when ``model_id`` is given.
            min_confidence: Scores below this value are omitted from the result.
            multi_label: Forwarded to the pipeline call.
            **kwargs: Extra pipeline arguments; also recorded in the result parameters.

        Returns:
            A ClassificationResult holding the retained category scores.

        Raises:
            ValueError: If ``categories`` is empty.
            TypeError: If no model is given and the content type is unsupported.
            ClassificationError: If the pipeline cannot be loaded or inference fails.
        """
        # Validate before anything that might load a model.
        if not categories:
            raise ValueError("Categories list cannot be empty.")

        model_id, effective_using = self._resolve_model(
            item_content, model_id, using, "Unsupported item_content type"
        )

        pipeline_instance = self._get_pipeline(model_id, effective_using)
        timestamp = datetime.now()
        parameters = {  # Store parameters used for this run
            "categories": categories,
            "model_id": model_id,
            "using": effective_using,
            "min_confidence": min_confidence,
            "multi_label": multi_label,
            **kwargs,
        }

        logger.debug(f"Classifying content (type: {type(item_content).__name__}) with model '{model_id}'")
        try:
            # The zero-shot pipelines take the candidate labels as `candidate_labels`.
            result_raw = pipeline_instance(
                item_content, candidate_labels=categories, multi_label=multi_label, **kwargs
            )
            logger.debug(f"Raw pipeline result: {result_raw}")

            return ClassificationResult(
                model_id=model_id,
                using=effective_using,
                timestamp=timestamp,
                parameters=parameters,
                scores=self._scores_from_raw(result_raw, min_confidence, model_id, in_batch=False),
            )
        except Exception as e:
            logger.error(f"Classification failed for model '{model_id}': {e}", exc_info=True)
            raise ClassificationError(f"Classification failed using model '{model_id}'. Error: {e}") from e

    def classify_batch(
        self,
        item_contents: List[Union[str, "Image.Image"]],
        categories: List[str],
        model_id: Optional[str] = None,
        using: Optional[str] = None,
        min_confidence: float = 0.0,
        multi_label: bool = False,
        batch_size: int = 8,
        progress_bar: bool = True,
        **kwargs,
    ) -> List["ClassificationResult"]:
        """Classify many items in one pipeline call, returning one result per item.

        The model/mode is resolved from the first item (the batch is assumed to be
        homogeneous). All results share one timestamp and one parameters dict.
        A result item that cannot be parsed yields a ClassificationResult with no
        scores instead of aborting the batch.

        Args:
            item_contents: Texts or PIL images to classify; an empty list returns [].
            categories: Candidate labels for every item.
            model_id: Model to use; when None, chosen from the type of the first item.
            using: 'text' or 'vision'; only consulted when ``model_id`` is given.
            min_confidence: Scores below this value are omitted.
            multi_label: Forwarded to the pipeline call.
            batch_size: Forwarded to the pipeline call.
            progress_bar: Wrap the results in the project's tqdm class when True.
            **kwargs: Extra pipeline arguments; also recorded in the result parameters.

        Raises:
            ValueError: If ``categories`` is empty (and ``item_contents`` is not).
            TypeError: If no model is given and the first item's type is unsupported.
            ClassificationError: If the pipeline cannot be loaded or the batch call fails.
        """
        if not item_contents:
            return []
        # Validate before anything that might load a model.
        if not categories:
            raise ValueError("Categories list cannot be empty.")

        model_id, effective_using = self._resolve_model(
            item_contents[0], model_id, using, "Unsupported item_content type in batch"
        )

        pipeline_instance = self._get_pipeline(model_id, effective_using)
        timestamp = datetime.now()  # Single timestamp for the batch run
        parameters = {  # Parameters for the whole batch
            "categories": categories,
            "model_id": model_id,
            "using": effective_using,
            "min_confidence": min_confidence,
            "multi_label": multi_label,
            "batch_size": batch_size,
            **kwargs,
        }

        logger.info(f"Classifying batch of {len(item_contents)} items with model '{model_id}' (batch size: {batch_size})")
        total_items = len(item_contents)
        batch_results_list: List["ClassificationResult"] = []

        try:
            results_iterator = pipeline_instance(
                item_contents,
                candidate_labels=categories,
                multi_label=multi_label,
                batch_size=batch_size,
                **kwargs,
            )
            # A lone dict is a single result, not a sequence of per-item results;
            # iterating it would walk its keys instead.
            if isinstance(results_iterator, dict):
                results_iterator = [results_iterator]

            if progress_bar:
                tqdm_class = get_tqdm()
                results_iterator = tqdm_class(
                    results_iterator,
                    total=total_items,
                    desc=f"Classifying batch ({model_id})",
                    leave=False,  # Don't leave progress bar hanging
                )

            for raw_result in results_iterator:
                try:
                    scores_list = self._scores_from_raw(raw_result, min_confidence, model_id, in_batch=True)
                except Exception as proc_err:
                    logger.error(f"Error processing result item in batch: {proc_err}", exc_info=True)
                    scores_list = []  # keep alignment with the inputs

                batch_results_list.append(
                    ClassificationResult(
                        model_id=model_id,
                        using=effective_using,
                        timestamp=timestamp,  # Use same timestamp for batch
                        parameters=parameters,  # Use same params for batch
                        scores=scores_list,
                    )
                )

            if len(batch_results_list) != total_items:
                logger.warning(
                    f"Batch classification returned {len(batch_results_list)} results, but expected {total_items}. Results might be incomplete or misaligned."
                )

            return batch_results_list

        except Exception as e:
            logger.error(f"Batch classification failed for model '{model_id}': {e}", exc_info=True)
            raise ClassificationError(f"Batch classification failed using model '{model_id}'. Error: {e}") from e
|