natural-pdf 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- natural_pdf/__init__.py +1 -0
- natural_pdf/analyzers/layout/base.py +1 -5
- natural_pdf/analyzers/layout/gemini.py +61 -51
- natural_pdf/analyzers/layout/layout_analyzer.py +40 -11
- natural_pdf/analyzers/layout/layout_manager.py +26 -84
- natural_pdf/analyzers/layout/layout_options.py +7 -0
- natural_pdf/analyzers/layout/pdfplumber_table_finder.py +142 -0
- natural_pdf/analyzers/layout/surya.py +46 -123
- natural_pdf/analyzers/layout/tatr.py +51 -4
- natural_pdf/analyzers/text_structure.py +3 -5
- natural_pdf/analyzers/utils.py +3 -3
- natural_pdf/classification/manager.py +241 -158
- natural_pdf/classification/mixin.py +52 -38
- natural_pdf/classification/results.py +71 -45
- natural_pdf/collections/mixins.py +85 -20
- natural_pdf/collections/pdf_collection.py +245 -100
- natural_pdf/core/element_manager.py +30 -14
- natural_pdf/core/highlighting_service.py +13 -22
- natural_pdf/core/page.py +423 -101
- natural_pdf/core/pdf.py +694 -195
- natural_pdf/elements/base.py +134 -40
- natural_pdf/elements/collections.py +610 -134
- natural_pdf/elements/region.py +659 -90
- natural_pdf/elements/text.py +1 -1
- natural_pdf/export/mixin.py +137 -0
- natural_pdf/exporters/base.py +3 -3
- natural_pdf/exporters/paddleocr.py +4 -3
- natural_pdf/extraction/manager.py +50 -49
- natural_pdf/extraction/mixin.py +90 -57
- natural_pdf/extraction/result.py +9 -23
- natural_pdf/ocr/__init__.py +5 -5
- natural_pdf/ocr/engine_doctr.py +346 -0
- natural_pdf/ocr/ocr_factory.py +24 -4
- natural_pdf/ocr/ocr_manager.py +61 -25
- natural_pdf/ocr/ocr_options.py +70 -10
- natural_pdf/ocr/utils.py +6 -4
- natural_pdf/search/__init__.py +20 -34
- natural_pdf/search/haystack_search_service.py +309 -265
- natural_pdf/search/haystack_utils.py +99 -75
- natural_pdf/search/search_service_protocol.py +11 -12
- natural_pdf/selectors/parser.py +219 -143
- natural_pdf/utils/debug.py +3 -3
- natural_pdf/utils/identifiers.py +1 -1
- natural_pdf/utils/locks.py +1 -1
- natural_pdf/utils/packaging.py +8 -6
- natural_pdf/utils/text_extraction.py +24 -16
- natural_pdf/utils/tqdm_utils.py +18 -10
- natural_pdf/utils/visualization.py +18 -0
- natural_pdf/widgets/viewer.py +4 -25
- {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.10.dist-info}/METADATA +12 -3
- natural_pdf-0.1.10.dist-info/RECORD +80 -0
- {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.10.dist-info}/WHEEL +1 -1
- {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.10.dist-info}/top_level.txt +0 -2
- docs/api/index.md +0 -386
- docs/assets/favicon.png +0 -3
- docs/assets/favicon.svg +0 -3
- docs/assets/javascripts/custom.js +0 -17
- docs/assets/logo.svg +0 -3
- docs/assets/sample-screen.png +0 -0
- docs/assets/social-preview.png +0 -17
- docs/assets/social-preview.svg +0 -17
- docs/assets/stylesheets/custom.css +0 -65
- docs/categorizing-documents/index.md +0 -168
- docs/data-extraction/index.md +0 -87
- docs/document-qa/index.ipynb +0 -435
- docs/document-qa/index.md +0 -79
- docs/element-selection/index.ipynb +0 -969
- docs/element-selection/index.md +0 -249
- docs/finetuning/index.md +0 -176
- docs/index.md +0 -189
- docs/installation/index.md +0 -69
- docs/interactive-widget/index.ipynb +0 -962
- docs/interactive-widget/index.md +0 -12
- docs/layout-analysis/index.ipynb +0 -818
- docs/layout-analysis/index.md +0 -185
- docs/ocr/index.md +0 -256
- docs/pdf-navigation/index.ipynb +0 -314
- docs/pdf-navigation/index.md +0 -97
- docs/regions/index.ipynb +0 -816
- docs/regions/index.md +0 -294
- docs/tables/index.ipynb +0 -658
- docs/tables/index.md +0 -144
- docs/text-analysis/index.ipynb +0 -370
- docs/text-analysis/index.md +0 -105
- docs/text-extraction/index.ipynb +0 -1478
- docs/text-extraction/index.md +0 -292
- docs/tutorials/01-loading-and-extraction.ipynb +0 -1873
- docs/tutorials/01-loading-and-extraction.md +0 -95
- docs/tutorials/02-finding-elements.ipynb +0 -417
- docs/tutorials/02-finding-elements.md +0 -149
- docs/tutorials/03-extracting-blocks.ipynb +0 -152
- docs/tutorials/03-extracting-blocks.md +0 -48
- docs/tutorials/04-table-extraction.ipynb +0 -119
- docs/tutorials/04-table-extraction.md +0 -50
- docs/tutorials/05-excluding-content.ipynb +0 -275
- docs/tutorials/05-excluding-content.md +0 -109
- docs/tutorials/06-document-qa.ipynb +0 -337
- docs/tutorials/06-document-qa.md +0 -91
- docs/tutorials/07-layout-analysis.ipynb +0 -293
- docs/tutorials/07-layout-analysis.md +0 -66
- docs/tutorials/07-working-with-regions.ipynb +0 -414
- docs/tutorials/07-working-with-regions.md +0 -151
- docs/tutorials/08-spatial-navigation.ipynb +0 -513
- docs/tutorials/08-spatial-navigation.md +0 -190
- docs/tutorials/09-section-extraction.ipynb +0 -2439
- docs/tutorials/09-section-extraction.md +0 -256
- docs/tutorials/10-form-field-extraction.ipynb +0 -517
- docs/tutorials/10-form-field-extraction.md +0 -201
- docs/tutorials/11-enhanced-table-processing.ipynb +0 -59
- docs/tutorials/11-enhanced-table-processing.md +0 -9
- docs/tutorials/12-ocr-integration.ipynb +0 -3712
- docs/tutorials/12-ocr-integration.md +0 -137
- docs/tutorials/13-semantic-search.ipynb +0 -1718
- docs/tutorials/13-semantic-search.md +0 -77
- docs/visual-debugging/index.ipynb +0 -2970
- docs/visual-debugging/index.md +0 -157
- docs/visual-debugging/region.png +0 -0
- natural_pdf/templates/finetune/fine_tune_paddleocr.md +0 -420
- natural_pdf/templates/spa/css/style.css +0 -334
- natural_pdf/templates/spa/index.html +0 -31
- natural_pdf/templates/spa/js/app.js +0 -472
- natural_pdf/templates/spa/words.txt +0 -235976
- natural_pdf/widgets/frontend/viewer.js +0 -88
- natural_pdf-0.1.8.dist-info/RECORD +0 -156
- notebooks/Examples.ipynb +0 -1293
- pdfs/.gitkeep +0 -0
- pdfs/01-practice.pdf +0 -543
- pdfs/0500000US42001.pdf +0 -0
- pdfs/0500000US42007.pdf +0 -0
- pdfs/2014 Statistics.pdf +0 -0
- pdfs/2019 Statistics.pdf +0 -0
- pdfs/Atlanta_Public_Schools_GA_sample.pdf +0 -0
- pdfs/needs-ocr.pdf +0 -0
- {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.10.dist-info}/licenses/LICENSE +0 -0
@@ -1,28 +1,35 @@
|
|
1
1
|
import logging
|
2
2
|
import time
|
3
3
|
from datetime import datetime
|
4
|
-
from typing import TYPE_CHECKING, Any, Dict, List, Optional,
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
5
|
+
|
6
|
+
from PIL import Image
|
5
7
|
|
6
8
|
# Use try-except for robustness if dependencies are missing
|
7
9
|
try:
|
8
10
|
import torch
|
9
|
-
from
|
10
|
-
|
11
|
+
from transformers import (
|
12
|
+
AutoModelForSequenceClassification,
|
13
|
+
AutoModelForZeroShotImageClassification,
|
14
|
+
AutoTokenizer,
|
15
|
+
pipeline,
|
16
|
+
)
|
17
|
+
|
11
18
|
_CLASSIFICATION_AVAILABLE = True
|
12
19
|
except ImportError:
|
13
20
|
_CLASSIFICATION_AVAILABLE = False
|
14
21
|
# Define dummy types for type hinting if imports fail
|
15
|
-
Image = type("Image", (), {})
|
16
22
|
pipeline = object
|
17
23
|
AutoTokenizer = object
|
18
24
|
AutoModelForZeroShotImageClassification = object
|
19
25
|
AutoModelForSequenceClassification = object
|
20
26
|
torch = None
|
21
27
|
|
22
|
-
# Import result classes
|
23
|
-
from .results import ClassificationResult, CategoryScore
|
24
28
|
from natural_pdf.utils.tqdm_utils import get_tqdm
|
25
29
|
|
30
|
+
# Import result classes
|
31
|
+
from .results import CategoryScore, ClassificationResult
|
32
|
+
|
26
33
|
if TYPE_CHECKING:
|
27
34
|
from transformers import Pipeline
|
28
35
|
|
@@ -34,8 +41,10 @@ _PIPELINE_CACHE: Dict[str, "Pipeline"] = {}
|
|
34
41
|
_TOKENIZER_CACHE: Dict[str, Any] = {}
|
35
42
|
_MODEL_CACHE: Dict[str, Any] = {}
|
36
43
|
|
44
|
+
|
37
45
|
class ClassificationError(Exception):
|
38
46
|
"""Custom exception for classification errors."""
|
47
|
+
|
39
48
|
pass
|
40
49
|
|
41
50
|
|
@@ -60,10 +69,12 @@ class ClassificationManager:
|
|
60
69
|
if not _CLASSIFICATION_AVAILABLE:
|
61
70
|
raise ImportError(
|
62
71
|
"Classification dependencies missing. "
|
63
|
-
|
72
|
+
'Install with: pip install "natural-pdf[classification]"'
|
64
73
|
)
|
65
74
|
|
66
|
-
self.pipelines: Dict[Tuple[str, str], "Pipeline"] =
|
75
|
+
self.pipelines: Dict[Tuple[str, str], "Pipeline"] = (
|
76
|
+
{}
|
77
|
+
) # Cache: (model_id, device) -> pipeline
|
67
78
|
|
68
79
|
self.device = default_device
|
69
80
|
logger.info(f"ClassificationManager initialized on device: {self.device}")
|
@@ -76,7 +87,9 @@ class ClassificationManager:
|
|
76
87
|
"""Get or create a classification pipeline."""
|
77
88
|
cache_key = f"{model_id}_{using}_{self.device}"
|
78
89
|
if cache_key not in _PIPELINE_CACHE:
|
79
|
-
logger.info(
|
90
|
+
logger.info(
|
91
|
+
f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'..."
|
92
|
+
)
|
80
93
|
start_time = time.time()
|
81
94
|
try:
|
82
95
|
task = (
|
@@ -84,16 +97,19 @@ class ClassificationManager:
|
|
84
97
|
if using == "text"
|
85
98
|
else "zero-shot-image-classification"
|
86
99
|
)
|
87
|
-
_PIPELINE_CACHE[cache_key] = pipeline(
|
88
|
-
task,
|
89
|
-
model=model_id,
|
90
|
-
device=self.device
|
91
|
-
)
|
100
|
+
_PIPELINE_CACHE[cache_key] = pipeline(task, model=model_id, device=self.device)
|
92
101
|
end_time = time.time()
|
93
|
-
logger.info(
|
102
|
+
logger.info(
|
103
|
+
f"Pipeline for '{model_id}' loaded in {end_time - start_time:.2f} seconds."
|
104
|
+
)
|
94
105
|
except Exception as e:
|
95
|
-
logger.error(
|
96
|
-
|
106
|
+
logger.error(
|
107
|
+
f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}",
|
108
|
+
exc_info=True,
|
109
|
+
)
|
110
|
+
raise ClassificationError(
|
111
|
+
f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task."
|
112
|
+
) from e
|
97
113
|
return _PIPELINE_CACHE[cache_key]
|
98
114
|
|
99
115
|
def infer_using(self, model_id: str, using: Optional[str] = None) -> str:
|
@@ -103,241 +119,308 @@ class ClassificationManager:
|
|
103
119
|
|
104
120
|
# Simple inference based on common model names
|
105
121
|
normalized_model_id = model_id.lower()
|
106
|
-
if
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
122
|
+
if (
|
123
|
+
"clip" in normalized_model_id
|
124
|
+
or "vit" in normalized_model_id
|
125
|
+
or "siglip" in normalized_model_id
|
126
|
+
):
|
127
|
+
logger.debug(f"Inferred using='vision' for model '{model_id}'")
|
128
|
+
return "vision"
|
129
|
+
if (
|
130
|
+
"bart" in normalized_model_id
|
131
|
+
or "bert" in normalized_model_id
|
132
|
+
or "mnli" in normalized_model_id
|
133
|
+
or "xnli" in normalized_model_id
|
134
|
+
or "deberta" in normalized_model_id
|
135
|
+
):
|
136
|
+
logger.debug(f"Inferred using='text' for model '{model_id}'")
|
137
|
+
return "text"
|
112
138
|
|
113
139
|
# Fallback or raise error? Let's try loading text first, then vision.
|
114
|
-
logger.warning(
|
140
|
+
logger.warning(
|
141
|
+
f"Could not reliably infer mode for '{model_id}'. Trying text, then vision pipeline loading."
|
142
|
+
)
|
115
143
|
try:
|
116
144
|
self._get_pipeline(model_id, "text")
|
117
145
|
logger.info(f"Successfully loaded '{model_id}' as a text model.")
|
118
146
|
return "text"
|
119
147
|
except Exception:
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
148
|
+
logger.warning(f"Failed to load '{model_id}' as text model. Trying vision.")
|
149
|
+
try:
|
150
|
+
self._get_pipeline(model_id, "vision")
|
151
|
+
logger.info(f"Successfully loaded '{model_id}' as a vision model.")
|
152
|
+
return "vision"
|
153
|
+
except Exception as e_vision:
|
154
|
+
logger.error(
|
155
|
+
f"Failed to load '{model_id}' as either text or vision model.", exc_info=True
|
156
|
+
)
|
157
|
+
raise ClassificationError(
|
158
|
+
f"Cannot determine mode for model '{model_id}'. Please specify `using='text'` or `using='vision'`. Error: {e_vision}"
|
159
|
+
)
|
128
160
|
|
129
161
|
def classify_item(
|
130
162
|
self,
|
131
163
|
item_content: Union[str, Image.Image],
|
132
|
-
|
164
|
+
labels: List[str],
|
133
165
|
model_id: Optional[str] = None,
|
134
166
|
using: Optional[str] = None,
|
135
167
|
min_confidence: float = 0.0,
|
136
168
|
multi_label: bool = False,
|
137
|
-
**kwargs
|
138
|
-
) -> ClassificationResult:
|
169
|
+
**kwargs,
|
170
|
+
) -> ClassificationResult: # Return ClassificationResult
|
139
171
|
"""Classifies a single item (text or image)."""
|
140
172
|
|
141
173
|
# Determine model and engine type
|
142
174
|
effective_using = using
|
143
175
|
if model_id is None:
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
176
|
+
# Try inferring based on content type
|
177
|
+
if isinstance(item_content, str):
|
178
|
+
effective_using = "text"
|
179
|
+
model_id = self.DEFAULT_TEXT_MODEL
|
180
|
+
elif isinstance(item_content, Image.Image):
|
181
|
+
effective_using = "vision"
|
182
|
+
model_id = self.DEFAULT_VISION_MODEL
|
183
|
+
else:
|
184
|
+
raise TypeError(f"Unsupported item_content type: {type(item_content)}")
|
153
185
|
else:
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
186
|
+
# Infer engine type if not given
|
187
|
+
effective_using = self.infer_using(model_id, using)
|
188
|
+
# Set default model if needed (though should usually be provided if engine known)
|
189
|
+
if model_id is None:
|
190
|
+
model_id = (
|
191
|
+
self.DEFAULT_TEXT_MODEL
|
192
|
+
if effective_using == "text"
|
193
|
+
else self.DEFAULT_VISION_MODEL
|
194
|
+
)
|
159
195
|
|
160
|
-
if not
|
161
|
-
|
196
|
+
if not labels:
|
197
|
+
raise ValueError("Labels list cannot be empty.")
|
162
198
|
|
163
199
|
pipeline_instance = self._get_pipeline(model_id, effective_using)
|
164
200
|
timestamp = datetime.now()
|
165
|
-
parameters = {
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
**kwargs
|
201
|
+
parameters = { # Store parameters used for this run
|
202
|
+
"labels": labels,
|
203
|
+
"model_id": model_id,
|
204
|
+
"using": effective_using,
|
205
|
+
"min_confidence": min_confidence,
|
206
|
+
"multi_label": multi_label,
|
207
|
+
**kwargs,
|
172
208
|
}
|
173
209
|
|
174
|
-
logger.debug(
|
210
|
+
logger.debug(
|
211
|
+
f"Classifying content (type: {type(item_content).__name__}) with model '{model_id}'"
|
212
|
+
)
|
175
213
|
try:
|
176
214
|
# Handle potential kwargs for specific pipelines if needed
|
177
215
|
# The zero-shot pipelines expect `candidate_labels`
|
178
|
-
result_raw = pipeline_instance(
|
216
|
+
result_raw = pipeline_instance(
|
217
|
+
item_content, candidate_labels=labels, multi_label=multi_label, **kwargs
|
218
|
+
)
|
179
219
|
logger.debug(f"Raw pipeline result: {result_raw}")
|
180
220
|
|
181
|
-
# --- Process raw result into ClassificationResult --- #
|
221
|
+
# --- Process raw result into ClassificationResult --- #
|
182
222
|
scores_list: List[CategoryScore] = []
|
183
|
-
|
223
|
+
|
184
224
|
# Handle text pipeline format (dict with 'labels' and 'scores')
|
185
|
-
if isinstance(result_raw, dict) and
|
186
|
-
for label, score_val in zip(result_raw[
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
225
|
+
if isinstance(result_raw, dict) and "labels" in result_raw and "scores" in result_raw:
|
226
|
+
for label, score_val in zip(result_raw["labels"], result_raw["scores"]):
|
227
|
+
if score_val >= min_confidence:
|
228
|
+
try:
|
229
|
+
scores_list.append(CategoryScore(label, score_val))
|
230
|
+
except (ValueError, TypeError) as score_err:
|
231
|
+
logger.warning(
|
232
|
+
f"Skipping invalid score from text pipeline: label='{label}', score={score_val}. Error: {score_err}"
|
233
|
+
)
|
192
234
|
# Handle vision pipeline format (list of dicts with 'label' and 'score')
|
193
|
-
elif isinstance(result_raw, list) and all(
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
235
|
+
elif isinstance(result_raw, list) and all(
|
236
|
+
isinstance(item, dict) and "label" in item and "score" in item
|
237
|
+
for item in result_raw
|
238
|
+
):
|
239
|
+
for item in result_raw:
|
240
|
+
score_val = item["score"]
|
241
|
+
label = item["label"]
|
242
|
+
if score_val >= min_confidence:
|
243
|
+
try:
|
244
|
+
scores_list.append(CategoryScore(label, score_val))
|
245
|
+
except (ValueError, TypeError) as score_err:
|
246
|
+
logger.warning(
|
247
|
+
f"Skipping invalid score from vision pipeline: label='{label}', score={score_val}. Error: {score_err}"
|
248
|
+
)
|
202
249
|
else:
|
203
|
-
|
204
|
-
|
205
|
-
|
250
|
+
logger.warning(
|
251
|
+
f"Unexpected raw result format from pipeline for model '{model_id}': {type(result_raw)}. Cannot extract scores."
|
252
|
+
)
|
253
|
+
# Return empty result?
|
254
|
+
# scores_list = []
|
206
255
|
|
207
|
-
|
256
|
+
# ClassificationResult now calculates top score/category internally
|
257
|
+
result_obj = ClassificationResult(
|
258
|
+
scores=scores_list, # Pass the filtered list
|
208
259
|
model_id=model_id,
|
209
260
|
using=effective_using,
|
210
|
-
timestamp=timestamp,
|
211
261
|
parameters=parameters,
|
212
|
-
|
262
|
+
timestamp=timestamp,
|
213
263
|
)
|
264
|
+
return result_obj
|
214
265
|
# --- End Processing --- #
|
215
266
|
|
216
267
|
except Exception as e:
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
268
|
+
logger.error(f"Classification failed for model '{model_id}': {e}", exc_info=True)
|
269
|
+
# Return an empty result object on failure?
|
270
|
+
# return ClassificationResult(model_id=model_id, engine_type=engine_type, timestamp=timestamp, parameters=parameters, scores=[])
|
271
|
+
raise ClassificationError(
|
272
|
+
f"Classification failed using model '{model_id}'. Error: {e}"
|
273
|
+
) from e
|
221
274
|
|
222
275
|
def classify_batch(
|
223
276
|
self,
|
224
277
|
item_contents: List[Union[str, Image.Image]],
|
225
|
-
|
278
|
+
labels: List[str],
|
226
279
|
model_id: Optional[str] = None,
|
227
280
|
using: Optional[str] = None,
|
228
281
|
min_confidence: float = 0.0,
|
229
282
|
multi_label: bool = False,
|
230
283
|
batch_size: int = 8,
|
231
284
|
progress_bar: bool = True,
|
232
|
-
**kwargs
|
233
|
-
) -> List[ClassificationResult]:
|
285
|
+
**kwargs,
|
286
|
+
) -> List[ClassificationResult]: # Return list of ClassificationResult
|
234
287
|
"""Classifies a batch of items (text or image) using the pipeline's batching."""
|
235
288
|
if not item_contents:
|
236
|
-
|
289
|
+
return []
|
237
290
|
|
238
291
|
# Determine model and engine type (assuming uniform type in batch)
|
239
292
|
first_item = item_contents[0]
|
240
293
|
effective_using = using
|
241
294
|
if model_id is None:
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
295
|
+
if isinstance(first_item, str):
|
296
|
+
effective_using = "text"
|
297
|
+
model_id = self.DEFAULT_TEXT_MODEL
|
298
|
+
elif isinstance(first_item, Image.Image):
|
299
|
+
effective_using = "vision"
|
300
|
+
model_id = self.DEFAULT_VISION_MODEL
|
301
|
+
else:
|
302
|
+
raise TypeError(f"Unsupported item_content type in batch: {type(first_item)}")
|
250
303
|
else:
|
251
|
-
|
252
|
-
|
253
|
-
|
304
|
+
effective_using = self.infer_using(model_id, using)
|
305
|
+
if model_id is None:
|
306
|
+
model_id = (
|
307
|
+
self.DEFAULT_TEXT_MODEL
|
308
|
+
if effective_using == "text"
|
309
|
+
else self.DEFAULT_VISION_MODEL
|
310
|
+
)
|
254
311
|
|
255
|
-
if not
|
256
|
-
|
312
|
+
if not labels:
|
313
|
+
raise ValueError("Labels list cannot be empty.")
|
257
314
|
|
258
315
|
pipeline_instance = self._get_pipeline(model_id, effective_using)
|
259
|
-
timestamp = datetime.now()
|
260
|
-
parameters = {
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
**kwargs
|
316
|
+
timestamp = datetime.now() # Single timestamp for the batch run
|
317
|
+
parameters = { # Parameters for the whole batch
|
318
|
+
"labels": labels,
|
319
|
+
"model_id": model_id,
|
320
|
+
"using": effective_using,
|
321
|
+
"min_confidence": min_confidence,
|
322
|
+
"multi_label": multi_label,
|
323
|
+
"batch_size": batch_size,
|
324
|
+
**kwargs,
|
268
325
|
}
|
269
326
|
|
270
|
-
logger.info(
|
327
|
+
logger.info(
|
328
|
+
f"Classifying batch of {len(item_contents)} items with model '{model_id}' (batch size: {batch_size})"
|
329
|
+
)
|
271
330
|
batch_results_list: List[ClassificationResult] = []
|
272
331
|
|
273
332
|
try:
|
274
333
|
# Use pipeline directly for batching
|
275
334
|
results_iterator = pipeline_instance(
|
276
335
|
item_contents,
|
277
|
-
candidate_labels=
|
336
|
+
candidate_labels=labels,
|
278
337
|
multi_label=multi_label,
|
279
338
|
batch_size=batch_size,
|
280
|
-
**kwargs
|
339
|
+
**kwargs,
|
281
340
|
)
|
282
|
-
|
341
|
+
|
283
342
|
# Wrap with tqdm for progress if requested
|
284
343
|
total_items = len(item_contents)
|
285
344
|
if progress_bar:
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
345
|
+
# Get the appropriate tqdm class
|
346
|
+
tqdm_class = get_tqdm()
|
347
|
+
results_iterator = tqdm_class(
|
348
|
+
results_iterator,
|
349
|
+
total=total_items,
|
350
|
+
desc=f"Classifying batch ({model_id})",
|
351
|
+
leave=False, # Don't leave progress bar hanging
|
352
|
+
)
|
294
353
|
|
295
354
|
for raw_result in results_iterator:
|
296
|
-
# --- Process each raw result (which corresponds to ONE input item) --- #
|
355
|
+
# --- Process each raw result (which corresponds to ONE input item) --- #
|
297
356
|
scores_list: List[CategoryScore] = []
|
298
357
|
try:
|
299
358
|
# Check for text format (dict with 'labels' and 'scores')
|
300
|
-
if
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
359
|
+
if (
|
360
|
+
isinstance(raw_result, dict)
|
361
|
+
and "labels" in raw_result
|
362
|
+
and "scores" in raw_result
|
363
|
+
):
|
364
|
+
for label, score_val in zip(raw_result["labels"], raw_result["scores"]):
|
365
|
+
if score_val >= min_confidence:
|
366
|
+
try:
|
367
|
+
scores_list.append(CategoryScore(label, score_val))
|
368
|
+
except (ValueError, TypeError) as score_err:
|
369
|
+
logger.warning(
|
370
|
+
f"Skipping invalid score from text pipeline batch: label='{label}', score={score_val}. Error: {score_err}"
|
371
|
+
)
|
307
372
|
# Check for vision format (list of dicts with 'label' and 'score')
|
308
373
|
elif isinstance(raw_result, list):
|
309
|
-
|
310
|
-
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
374
|
+
for item in raw_result:
|
375
|
+
try:
|
376
|
+
score_val = item["score"]
|
377
|
+
label = item["label"]
|
378
|
+
if score_val >= min_confidence:
|
379
|
+
scores_list.append(CategoryScore(label, score_val))
|
380
|
+
except (KeyError, ValueError, TypeError) as item_err:
|
381
|
+
logger.warning(
|
382
|
+
f"Skipping invalid item in vision result list from batch: {item}. Error: {item_err}"
|
383
|
+
)
|
317
384
|
else:
|
318
|
-
|
319
|
-
|
385
|
+
logger.warning(
|
386
|
+
f"Unexpected raw result format in batch item from model '{model_id}': {type(raw_result)}. Cannot extract scores."
|
387
|
+
)
|
388
|
+
|
320
389
|
except Exception as proc_err:
|
321
|
-
|
322
|
-
|
390
|
+
logger.error(
|
391
|
+
f"Error processing result item in batch: {proc_err}", exc_info=True
|
392
|
+
)
|
393
|
+
# scores_list remains empty for this item
|
394
|
+
|
395
|
+
# --- Determine top category and score ---
|
396
|
+
scores_list.sort(key=lambda s: s.score, reverse=True)
|
397
|
+
top_category = scores_list[0].label
|
398
|
+
top_score = scores_list[0].score
|
399
|
+
# --- End Determine top category ---
|
323
400
|
|
324
401
|
# Append result object for this item
|
325
|
-
batch_results_list.append(
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
402
|
+
batch_results_list.append(
|
403
|
+
ClassificationResult(
|
404
|
+
scores=scores_list, # Pass the full list, init will sort/filter
|
405
|
+
model_id=model_id,
|
406
|
+
using=effective_using,
|
407
|
+
timestamp=timestamp, # Use same timestamp for batch
|
408
|
+
parameters=parameters, # Use same params for batch
|
409
|
+
)
|
410
|
+
)
|
332
411
|
# --- End Processing --- #
|
333
412
|
|
334
413
|
if len(batch_results_list) != total_items:
|
335
|
-
|
414
|
+
logger.warning(
|
415
|
+
f"Batch classification returned {len(batch_results_list)} results, but expected {total_items}. Results might be incomplete or misaligned."
|
416
|
+
)
|
336
417
|
|
337
418
|
return batch_results_list
|
338
419
|
|
339
420
|
except Exception as e:
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
421
|
+
logger.error(f"Batch classification failed for model '{model_id}': {e}", exc_info=True)
|
422
|
+
# Return list of empty results?
|
423
|
+
# return [ClassificationResult(model_id=model_id, s=engine_type, timestamp=timestamp, parameters=parameters, scores=[]) for _ in item_contents]
|
424
|
+
raise ClassificationError(
|
425
|
+
f"Batch classification failed using model '{model_id}'. Error: {e}"
|
426
|
+
) from e
|