natural-pdf 0.1.8__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- natural_pdf/__init__.py +1 -0
- natural_pdf/analyzers/layout/base.py +1 -5
- natural_pdf/analyzers/layout/gemini.py +61 -51
- natural_pdf/analyzers/layout/layout_analyzer.py +40 -11
- natural_pdf/analyzers/layout/layout_manager.py +26 -84
- natural_pdf/analyzers/layout/layout_options.py +7 -0
- natural_pdf/analyzers/layout/pdfplumber_table_finder.py +142 -0
- natural_pdf/analyzers/layout/surya.py +46 -123
- natural_pdf/analyzers/layout/tatr.py +51 -4
- natural_pdf/analyzers/text_structure.py +3 -5
- natural_pdf/analyzers/utils.py +3 -3
- natural_pdf/classification/manager.py +230 -151
- natural_pdf/classification/mixin.py +49 -35
- natural_pdf/classification/results.py +64 -46
- natural_pdf/collections/mixins.py +68 -20
- natural_pdf/collections/pdf_collection.py +177 -64
- natural_pdf/core/element_manager.py +30 -14
- natural_pdf/core/highlighting_service.py +13 -22
- natural_pdf/core/page.py +423 -101
- natural_pdf/core/pdf.py +633 -190
- natural_pdf/elements/base.py +134 -40
- natural_pdf/elements/collections.py +503 -131
- natural_pdf/elements/region.py +659 -90
- natural_pdf/elements/text.py +1 -1
- natural_pdf/export/mixin.py +137 -0
- natural_pdf/exporters/base.py +3 -3
- natural_pdf/exporters/paddleocr.py +4 -3
- natural_pdf/extraction/manager.py +50 -49
- natural_pdf/extraction/mixin.py +90 -57
- natural_pdf/extraction/result.py +9 -23
- natural_pdf/ocr/__init__.py +5 -5
- natural_pdf/ocr/engine_doctr.py +346 -0
- natural_pdf/ocr/ocr_factory.py +24 -4
- natural_pdf/ocr/ocr_manager.py +61 -25
- natural_pdf/ocr/ocr_options.py +70 -10
- natural_pdf/ocr/utils.py +6 -4
- natural_pdf/search/__init__.py +20 -34
- natural_pdf/search/haystack_search_service.py +309 -265
- natural_pdf/search/haystack_utils.py +99 -75
- natural_pdf/search/search_service_protocol.py +11 -12
- natural_pdf/selectors/parser.py +219 -143
- natural_pdf/utils/debug.py +3 -3
- natural_pdf/utils/identifiers.py +1 -1
- natural_pdf/utils/locks.py +1 -1
- natural_pdf/utils/packaging.py +8 -6
- natural_pdf/utils/text_extraction.py +24 -16
- natural_pdf/utils/tqdm_utils.py +18 -10
- natural_pdf/utils/visualization.py +18 -0
- natural_pdf/widgets/viewer.py +4 -25
- {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.9.dist-info}/METADATA +12 -3
- natural_pdf-0.1.9.dist-info/RECORD +80 -0
- {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.9.dist-info}/WHEEL +1 -1
- {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.9.dist-info}/top_level.txt +0 -2
- docs/api/index.md +0 -386
- docs/assets/favicon.png +0 -3
- docs/assets/favicon.svg +0 -3
- docs/assets/javascripts/custom.js +0 -17
- docs/assets/logo.svg +0 -3
- docs/assets/sample-screen.png +0 -0
- docs/assets/social-preview.png +0 -17
- docs/assets/social-preview.svg +0 -17
- docs/assets/stylesheets/custom.css +0 -65
- docs/categorizing-documents/index.md +0 -168
- docs/data-extraction/index.md +0 -87
- docs/document-qa/index.ipynb +0 -435
- docs/document-qa/index.md +0 -79
- docs/element-selection/index.ipynb +0 -969
- docs/element-selection/index.md +0 -249
- docs/finetuning/index.md +0 -176
- docs/index.md +0 -189
- docs/installation/index.md +0 -69
- docs/interactive-widget/index.ipynb +0 -962
- docs/interactive-widget/index.md +0 -12
- docs/layout-analysis/index.ipynb +0 -818
- docs/layout-analysis/index.md +0 -185
- docs/ocr/index.md +0 -256
- docs/pdf-navigation/index.ipynb +0 -314
- docs/pdf-navigation/index.md +0 -97
- docs/regions/index.ipynb +0 -816
- docs/regions/index.md +0 -294
- docs/tables/index.ipynb +0 -658
- docs/tables/index.md +0 -144
- docs/text-analysis/index.ipynb +0 -370
- docs/text-analysis/index.md +0 -105
- docs/text-extraction/index.ipynb +0 -1478
- docs/text-extraction/index.md +0 -292
- docs/tutorials/01-loading-and-extraction.ipynb +0 -1873
- docs/tutorials/01-loading-and-extraction.md +0 -95
- docs/tutorials/02-finding-elements.ipynb +0 -417
- docs/tutorials/02-finding-elements.md +0 -149
- docs/tutorials/03-extracting-blocks.ipynb +0 -152
- docs/tutorials/03-extracting-blocks.md +0 -48
- docs/tutorials/04-table-extraction.ipynb +0 -119
- docs/tutorials/04-table-extraction.md +0 -50
- docs/tutorials/05-excluding-content.ipynb +0 -275
- docs/tutorials/05-excluding-content.md +0 -109
- docs/tutorials/06-document-qa.ipynb +0 -337
- docs/tutorials/06-document-qa.md +0 -91
- docs/tutorials/07-layout-analysis.ipynb +0 -293
- docs/tutorials/07-layout-analysis.md +0 -66
- docs/tutorials/07-working-with-regions.ipynb +0 -414
- docs/tutorials/07-working-with-regions.md +0 -151
- docs/tutorials/08-spatial-navigation.ipynb +0 -513
- docs/tutorials/08-spatial-navigation.md +0 -190
- docs/tutorials/09-section-extraction.ipynb +0 -2439
- docs/tutorials/09-section-extraction.md +0 -256
- docs/tutorials/10-form-field-extraction.ipynb +0 -517
- docs/tutorials/10-form-field-extraction.md +0 -201
- docs/tutorials/11-enhanced-table-processing.ipynb +0 -59
- docs/tutorials/11-enhanced-table-processing.md +0 -9
- docs/tutorials/12-ocr-integration.ipynb +0 -3712
- docs/tutorials/12-ocr-integration.md +0 -137
- docs/tutorials/13-semantic-search.ipynb +0 -1718
- docs/tutorials/13-semantic-search.md +0 -77
- docs/visual-debugging/index.ipynb +0 -2970
- docs/visual-debugging/index.md +0 -157
- docs/visual-debugging/region.png +0 -0
- natural_pdf/templates/finetune/fine_tune_paddleocr.md +0 -420
- natural_pdf/templates/spa/css/style.css +0 -334
- natural_pdf/templates/spa/index.html +0 -31
- natural_pdf/templates/spa/js/app.js +0 -472
- natural_pdf/templates/spa/words.txt +0 -235976
- natural_pdf/widgets/frontend/viewer.js +0 -88
- natural_pdf-0.1.8.dist-info/RECORD +0 -156
- notebooks/Examples.ipynb +0 -1293
- pdfs/.gitkeep +0 -0
- pdfs/01-practice.pdf +0 -543
- pdfs/0500000US42001.pdf +0 -0
- pdfs/0500000US42007.pdf +0 -0
- pdfs/2014 Statistics.pdf +0 -0
- pdfs/2019 Statistics.pdf +0 -0
- pdfs/Atlanta_Public_Schools_GA_sample.pdf +0 -0
- pdfs/needs-ocr.pdf +0 -0
- {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.9.dist-info}/licenses/LICENSE +0 -0
natural_pdf/classification/manager.py (+230 -151)
@@ -1,28 +1,35 @@
 import logging
 import time
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional,
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
 
 # Use try-except for robustness if dependencies are missing
 try:
     import torch
-    from
-
+    from transformers import (
+        AutoModelForSequenceClassification,
+        AutoModelForZeroShotImageClassification,
+        AutoTokenizer,
+        pipeline,
+    )
+
     _CLASSIFICATION_AVAILABLE = True
 except ImportError:
     _CLASSIFICATION_AVAILABLE = False
     # Define dummy types for type hinting if imports fail
-    Image = type("Image", (), {})
     pipeline = object
     AutoTokenizer = object
     AutoModelForZeroShotImageClassification = object
     AutoModelForSequenceClassification = object
     torch = None
 
-# Import result classes
-from .results import ClassificationResult, CategoryScore
 from natural_pdf.utils.tqdm_utils import get_tqdm
 
+# Import result classes
+from .results import CategoryScore, ClassificationResult
+
 if TYPE_CHECKING:
     from transformers import Pipeline
 
@@ -34,8 +41,10 @@ _PIPELINE_CACHE: Dict[str, "Pipeline"] = {}
 _TOKENIZER_CACHE: Dict[str, Any] = {}
 _MODEL_CACHE: Dict[str, Any] = {}
 
+
 class ClassificationError(Exception):
     """Custom exception for classification errors."""
+
     pass
 
 
@@ -60,10 +69,12 @@ class ClassificationManager:
         if not _CLASSIFICATION_AVAILABLE:
             raise ImportError(
                 "Classification dependencies missing. "
-
+                'Install with: pip install "natural-pdf[classification]"'
             )
 
-        self.pipelines: Dict[Tuple[str, str], "Pipeline"] =
+        self.pipelines: Dict[Tuple[str, str], "Pipeline"] = (
+            {}
+        )  # Cache: (model_id, device) -> pipeline
 
         self.device = default_device
         logger.info(f"ClassificationManager initialized on device: {self.device}")
@@ -76,7 +87,9 @@ class ClassificationManager:
         """Get or create a classification pipeline."""
         cache_key = f"{model_id}_{using}_{self.device}"
         if cache_key not in _PIPELINE_CACHE:
-            logger.info(
+            logger.info(
+                f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'..."
+            )
             start_time = time.time()
             try:
                 task = (
@@ -84,16 +97,19 @@ class ClassificationManager:
                     if using == "text"
                     else "zero-shot-image-classification"
                 )
-                _PIPELINE_CACHE[cache_key] = pipeline(
-                    task,
-                    model=model_id,
-                    device=self.device
-                )
+                _PIPELINE_CACHE[cache_key] = pipeline(task, model=model_id, device=self.device)
                 end_time = time.time()
-                logger.info(
+                logger.info(
+                    f"Pipeline for '{model_id}' loaded in {end_time - start_time:.2f} seconds."
+                )
             except Exception as e:
-                logger.error(
-
+                logger.error(
+                    f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}",
+                    exc_info=True,
+                )
+                raise ClassificationError(
+                    f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task."
+                ) from e
         return _PIPELINE_CACHE[cache_key]
 
     def infer_using(self, model_id: str, using: Optional[str] = None) -> str:
@@ -103,28 +119,44 @@ class ClassificationManager:
 
         # Simple inference based on common model names
         normalized_model_id = model_id.lower()
-        if
-
-
-
-
-
+        if (
+            "clip" in normalized_model_id
+            or "vit" in normalized_model_id
+            or "siglip" in normalized_model_id
+        ):
+            logger.debug(f"Inferred using='vision' for model '{model_id}'")
+            return "vision"
+        if (
+            "bart" in normalized_model_id
+            or "bert" in normalized_model_id
+            or "mnli" in normalized_model_id
+            or "xnli" in normalized_model_id
+            or "deberta" in normalized_model_id
+        ):
+            logger.debug(f"Inferred using='text' for model '{model_id}'")
+            return "text"
 
         # Fallback or raise error? Let's try loading text first, then vision.
-        logger.warning(
+        logger.warning(
+            f"Could not reliably infer mode for '{model_id}'. Trying text, then vision pipeline loading."
+        )
         try:
             self._get_pipeline(model_id, "text")
             logger.info(f"Successfully loaded '{model_id}' as a text model.")
             return "text"
         except Exception:
-
-
-
-
-
-
-
-
+            logger.warning(f"Failed to load '{model_id}' as text model. Trying vision.")
+            try:
+                self._get_pipeline(model_id, "vision")
+                logger.info(f"Successfully loaded '{model_id}' as a vision model.")
+                return "vision"
+            except Exception as e_vision:
+                logger.error(
+                    f"Failed to load '{model_id}' as either text or vision model.", exc_info=True
+                )
+                raise ClassificationError(
+                    f"Cannot determine mode for model '{model_id}'. Please specify `using='text'` or `using='vision'`. Error: {e_vision}"
+                )
 
     def classify_item(
         self,
@@ -134,90 +166,109 @@ class ClassificationManager:
         using: Optional[str] = None,
         min_confidence: float = 0.0,
         multi_label: bool = False,
-        **kwargs
-    ) -> ClassificationResult:
+        **kwargs,
+    ) -> ClassificationResult:  # Return ClassificationResult
         """Classifies a single item (text or image)."""
 
         # Determine model and engine type
         effective_using = using
         if model_id is None:
-
-
-
-
-
-
-
-
-
+            # Try inferring based on content type
+            if isinstance(item_content, str):
+                effective_using = "text"
+                model_id = self.DEFAULT_TEXT_MODEL
+            elif isinstance(item_content, Image.Image):
+                effective_using = "vision"
+                model_id = self.DEFAULT_VISION_MODEL
+            else:
+                raise TypeError(f"Unsupported item_content type: {type(item_content)}")
         else:
-
-
-
-
-
+            # Infer engine type if not given
+            effective_using = self.infer_using(model_id, using)
+            # Set default model if needed (though should usually be provided if engine known)
+            if model_id is None:
+                model_id = (
+                    self.DEFAULT_TEXT_MODEL
+                    if effective_using == "text"
+                    else self.DEFAULT_VISION_MODEL
+                )
 
         if not categories:
-
+            raise ValueError("Categories list cannot be empty.")
 
         pipeline_instance = self._get_pipeline(model_id, effective_using)
         timestamp = datetime.now()
-        parameters = {
-
-
-
-
-
-            **kwargs
+        parameters = {  # Store parameters used for this run
+            "categories": categories,
+            "model_id": model_id,
+            "using": effective_using,
+            "min_confidence": min_confidence,
+            "multi_label": multi_label,
+            **kwargs,
         }
 
-        logger.debug(
+        logger.debug(
+            f"Classifying content (type: {type(item_content).__name__}) with model '{model_id}'"
+        )
         try:
             # Handle potential kwargs for specific pipelines if needed
             # The zero-shot pipelines expect `candidate_labels`
-            result_raw = pipeline_instance(
+            result_raw = pipeline_instance(
+                item_content, candidate_labels=categories, multi_label=multi_label, **kwargs
+            )
             logger.debug(f"Raw pipeline result: {result_raw}")
 
-            # --- Process raw result into ClassificationResult --- #
+            # --- Process raw result into ClassificationResult --- #
             scores_list: List[CategoryScore] = []
-
+
             # Handle text pipeline format (dict with 'labels' and 'scores')
-            if isinstance(result_raw, dict) and
-                for label, score_val in zip(result_raw[
-
-
-
-
-
+            if isinstance(result_raw, dict) and "labels" in result_raw and "scores" in result_raw:
+                for label, score_val in zip(result_raw["labels"], result_raw["scores"]):
+                    if score_val >= min_confidence:
+                        try:
+                            scores_list.append(CategoryScore(label=label, confidence=score_val))
+                        except (ValueError, TypeError) as score_err:
+                            logger.warning(
+                                f"Skipping invalid score from text pipeline: label='{label}', score={score_val}. Error: {score_err}"
+                            )
             # Handle vision pipeline format (list of dicts with 'label' and 'score')
-            elif isinstance(result_raw, list) and all(
-
-
-
-
-
-
-
-
+            elif isinstance(result_raw, list) and all(
+                isinstance(item, dict) and "label" in item and "score" in item
+                for item in result_raw
+            ):
+                for item in result_raw:
+                    score_val = item["score"]
+                    label = item["label"]
+                    if score_val >= min_confidence:
+                        try:
+                            scores_list.append(CategoryScore(label=label, confidence=score_val))
+                        except (ValueError, TypeError) as score_err:
+                            logger.warning(
+                                f"Skipping invalid score from vision pipeline: label='{label}', score={score_val}. Error: {score_err}"
+                            )
             else:
-
-
-
+                logger.warning(
+                    f"Unexpected raw result format from pipeline for model '{model_id}': {type(result_raw)}. Cannot extract scores."
+                )
+                # Return empty result?
+                # scores_list = []
 
             return ClassificationResult(
                 model_id=model_id,
                 using=effective_using,
                 timestamp=timestamp,
                 parameters=parameters,
-                scores=scores_list
+                scores=scores_list,
             )
             # --- End Processing --- #
 
         except Exception as e:
-
-
-
-
+            logger.error(f"Classification failed for model '{model_id}': {e}", exc_info=True)
+            # Return an empty result object on failure?
+            # return ClassificationResult(model_id=model_id, engine_type=engine_type, timestamp=timestamp, parameters=parameters, scores=[])
+            raise ClassificationError(
+                f"Classification failed using model '{model_id}'. Error: {e}"
+            ) from e
 
     def classify_batch(
         self,
@@ -229,45 +280,51 @@ class ClassificationManager:
         multi_label: bool = False,
         batch_size: int = 8,
         progress_bar: bool = True,
-        **kwargs
-    ) -> List[ClassificationResult]:
+        **kwargs,
+    ) -> List[ClassificationResult]:  # Return list of ClassificationResult
        """Classifies a batch of items (text or image) using the pipeline's batching."""
         if not item_contents:
-
+            return []
 
         # Determine model and engine type (assuming uniform type in batch)
         first_item = item_contents[0]
         effective_using = using
         if model_id is None:
-
-
-
-
-
-
-
-
+            if isinstance(first_item, str):
+                effective_using = "text"
+                model_id = self.DEFAULT_TEXT_MODEL
+            elif isinstance(first_item, Image.Image):
+                effective_using = "vision"
+                model_id = self.DEFAULT_VISION_MODEL
+            else:
+                raise TypeError(f"Unsupported item_content type in batch: {type(first_item)}")
         else:
-
-
-
+            effective_using = self.infer_using(model_id, using)
+            if model_id is None:
+                model_id = (
+                    self.DEFAULT_TEXT_MODEL
+                    if effective_using == "text"
+                    else self.DEFAULT_VISION_MODEL
+                )
 
         if not categories:
-
+            raise ValueError("Categories list cannot be empty.")
 
         pipeline_instance = self._get_pipeline(model_id, effective_using)
-        timestamp = datetime.now()
-        parameters = {
-
-
-
-
-
-
-            **kwargs
+        timestamp = datetime.now()  # Single timestamp for the batch run
+        parameters = {  # Parameters for the whole batch
+            "categories": categories,
+            "model_id": model_id,
+            "using": effective_using,
+            "min_confidence": min_confidence,
+            "multi_label": multi_label,
+            "batch_size": batch_size,
+            **kwargs,
         }
 
-        logger.info(
+        logger.info(
+            f"Classifying batch of {len(item_contents)} items with model '{model_id}' (batch size: {batch_size})"
+        )
         batch_results_list: List[ClassificationResult] = []
 
         try:
@@ -277,67 +334,89 @@ class ClassificationManager:
                 candidate_labels=categories,
                 multi_label=multi_label,
                 batch_size=batch_size,
-                **kwargs
+                **kwargs,
             )
-
+
             # Wrap with tqdm for progress if requested
             total_items = len(item_contents)
             if progress_bar:
-
-
-
-
-
-
-
-
+                # Get the appropriate tqdm class
+                tqdm_class = get_tqdm()
+                results_iterator = tqdm_class(
+                    results_iterator,
+                    total=total_items,
+                    desc=f"Classifying batch ({model_id})",
+                    leave=False,  # Don't leave progress bar hanging
+                )
 
             for raw_result in results_iterator:
-                # --- Process each raw result (which corresponds to ONE input item) --- #
+                # --- Process each raw result (which corresponds to ONE input item) --- #
                 scores_list: List[CategoryScore] = []
                 try:
                     # Check for text format (dict with 'labels' and 'scores')
-                    if
-
-
-
-
-
-
+                    if (
+                        isinstance(raw_result, dict)
+                        and "labels" in raw_result
+                        and "scores" in raw_result
+                    ):
+                        for label, score_val in zip(raw_result["labels"], raw_result["scores"]):
+                            if score_val >= min_confidence:
+                                try:
+                                    scores_list.append(
+                                        CategoryScore(label=label, confidence=score_val)
+                                    )
+                                except (ValueError, TypeError) as score_err:
+                                    logger.warning(
+                                        f"Skipping invalid score from text pipeline batch: label='{label}', score={score_val}. Error: {score_err}"
+                                    )
                     # Check for vision format (list of dicts with 'label' and 'score')
                     elif isinstance(raw_result, list):
-
-
-
-
-
-
-
-
+                        for item in raw_result:
+                            try:
+                                score_val = item["score"]
+                                label = item["label"]
+                                if score_val >= min_confidence:
+                                    scores_list.append(
+                                        CategoryScore(label=label, confidence=score_val)
+                                    )
+                            except (KeyError, ValueError, TypeError) as item_err:
+                                logger.warning(
+                                    f"Skipping invalid item in vision result list from batch: {item}. Error: {item_err}"
+                                )
                     else:
-
-
+                        logger.warning(
+                            f"Unexpected raw result format in batch item from model '{model_id}': {type(raw_result)}. Cannot extract scores."
+                        )
+
                 except Exception as proc_err:
-
-
+                    logger.error(
+                        f"Error processing result item in batch: {proc_err}", exc_info=True
+                    )
+                    # scores_list remains empty for this item
 
                 # Append result object for this item
-                batch_results_list.append(
-
-
-
-
-
-
+                batch_results_list.append(
+                    ClassificationResult(
+                        model_id=model_id,
+                        using=effective_using,
+                        timestamp=timestamp,  # Use same timestamp for batch
+                        parameters=parameters,  # Use same params for batch
+                        scores=scores_list,
+                    )
+                )
                 # --- End Processing --- #
 
             if len(batch_results_list) != total_items:
-
+                logger.warning(
+                    f"Batch classification returned {len(batch_results_list)} results, but expected {total_items}. Results might be incomplete or misaligned."
+                )
 
             return batch_results_list
 
         except Exception as e:
-
-
-
-
+            logger.error(f"Batch classification failed for model '{model_id}': {e}", exc_info=True)
+            # Return list of empty results?
+            # return [ClassificationResult(model_id=model_id, s=engine_type, timestamp=timestamp, parameters=parameters, scores=[]) for _ in item_contents]
+            raise ClassificationError(
+                f"Batch classification failed using model '{model_id}'. Error: {e}"
+            ) from e
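
For orientation, below is a minimal usage sketch of the reworked classification API shown in the manager.py diff above. It is not taken from the package documentation: the constructor call, the positional item_content argument, and the scores/label/confidence attributes on the result objects are assumptions based only on names visible in this diff.

# Hypothetical sketch -- not from the natural-pdf docs. Names come from the
# manager.py diff above; anything not visible there is an assumption.
from natural_pdf.classification.manager import ClassificationManager

manager = ClassificationManager()  # assumed: default_device handling as shown in the diff

# Per the diff, classify_item() infers using="text" for a str input and falls
# back to the manager's DEFAULT_TEXT_MODEL when model_id is None.
result = manager.classify_item(
    "Total revenue increased 4.2 percent over the prior fiscal year.",
    categories=["financial report", "meeting minutes", "press release"],
    min_confidence=0.1,
)

for score in result.scores:  # assumed attribute, mirroring the scores= argument
    print(score.label, score.confidence)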