natural-pdf 0.1.6__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. docs/categorizing-documents/index.md +168 -0
  2. docs/data-extraction/index.md +87 -0
  3. docs/element-selection/index.ipynb +218 -164
  4. docs/element-selection/index.md +20 -0
  5. docs/finetuning/index.md +176 -0
  6. docs/index.md +19 -0
  7. docs/ocr/index.md +63 -16
  8. docs/tutorials/01-loading-and-extraction.ipynb +411 -248
  9. docs/tutorials/02-finding-elements.ipynb +123 -46
  10. docs/tutorials/03-extracting-blocks.ipynb +24 -19
  11. docs/tutorials/04-table-extraction.ipynb +17 -12
  12. docs/tutorials/05-excluding-content.ipynb +37 -32
  13. docs/tutorials/06-document-qa.ipynb +36 -31
  14. docs/tutorials/07-layout-analysis.ipynb +45 -40
  15. docs/tutorials/07-working-with-regions.ipynb +61 -60
  16. docs/tutorials/08-spatial-navigation.ipynb +76 -71
  17. docs/tutorials/09-section-extraction.ipynb +160 -155
  18. docs/tutorials/10-form-field-extraction.ipynb +71 -66
  19. docs/tutorials/11-enhanced-table-processing.ipynb +11 -6
  20. docs/tutorials/12-ocr-integration.ipynb +3420 -312
  21. docs/tutorials/12-ocr-integration.md +68 -106
  22. docs/tutorials/13-semantic-search.ipynb +641 -251
  23. natural_pdf/__init__.py +3 -0
  24. natural_pdf/analyzers/layout/gemini.py +63 -47
  25. natural_pdf/classification/manager.py +343 -0
  26. natural_pdf/classification/mixin.py +149 -0
  27. natural_pdf/classification/results.py +62 -0
  28. natural_pdf/collections/mixins.py +63 -0
  29. natural_pdf/collections/pdf_collection.py +326 -17
  30. natural_pdf/core/element_manager.py +73 -4
  31. natural_pdf/core/page.py +255 -83
  32. natural_pdf/core/pdf.py +385 -367
  33. natural_pdf/elements/base.py +1 -3
  34. natural_pdf/elements/collections.py +279 -49
  35. natural_pdf/elements/region.py +106 -21
  36. natural_pdf/elements/text.py +5 -2
  37. natural_pdf/exporters/__init__.py +4 -0
  38. natural_pdf/exporters/base.py +61 -0
  39. natural_pdf/exporters/paddleocr.py +345 -0
  40. natural_pdf/extraction/manager.py +134 -0
  41. natural_pdf/extraction/mixin.py +246 -0
  42. natural_pdf/extraction/result.py +37 -0
  43. natural_pdf/ocr/__init__.py +16 -8
  44. natural_pdf/ocr/engine.py +46 -30
  45. natural_pdf/ocr/engine_easyocr.py +86 -42
  46. natural_pdf/ocr/engine_paddle.py +39 -28
  47. natural_pdf/ocr/engine_surya.py +32 -16
  48. natural_pdf/ocr/ocr_factory.py +34 -23
  49. natural_pdf/ocr/ocr_manager.py +98 -34
  50. natural_pdf/ocr/ocr_options.py +38 -10
  51. natural_pdf/ocr/utils.py +59 -33
  52. natural_pdf/qa/document_qa.py +0 -4
  53. natural_pdf/selectors/parser.py +363 -238
  54. natural_pdf/templates/finetune/fine_tune_paddleocr.md +420 -0
  55. natural_pdf/utils/debug.py +4 -2
  56. natural_pdf/utils/identifiers.py +9 -5
  57. natural_pdf/utils/locks.py +8 -0
  58. natural_pdf/utils/packaging.py +172 -105
  59. natural_pdf/utils/text_extraction.py +96 -65
  60. natural_pdf/utils/tqdm_utils.py +43 -0
  61. natural_pdf/utils/visualization.py +1 -1
  62. {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/METADATA +10 -3
  63. {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/RECORD +66 -51
  64. {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/WHEEL +1 -1
  65. {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/licenses/LICENSE +0 -0
  66. {natural_pdf-0.1.6.dist-info → natural_pdf-0.1.8.dist-info}/top_level.txt +0 -0
natural_pdf/__init__.py CHANGED
@@ -3,6 +3,8 @@ Natural PDF - A more intuitive interface for working with PDFs.
  """
 
  import logging
+ import os
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
  # Create library logger
  logger = logging.getLogger("natural_pdf")
@@ -33,6 +35,7 @@ def configure_logging(level=logging.INFO, handler=None):
 
      logger.propagate = False
 
+
  from natural_pdf.core.page import Page
  from natural_pdf.core.pdf import PDF
  from natural_pdf.elements.collections import ElementCollection
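
The new `os.environ["TOKENIZERS_PARALLELISM"] = "false"` line runs at import time, so it disables Hugging Face `tokenizers` parallelism (and its fork-related warning) before any `transformers` pipeline created by the new classification and extraction modules can initialize it. A minimal sketch of the same ordering constraint in user code; from 0.1.8 onward the package handles this itself:

```python
import os

# The variable only takes effect if set before the `tokenizers` library
# initializes, which is why natural_pdf 0.1.8 sets it in its own
# __init__.py rather than leaving it to callers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import natural_pdf  # noqa: E402  (import after env setup is intentional)
```
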
natural_pdf/analyzers/layout/gemini.py CHANGED
@@ -13,6 +13,7 @@ from PIL import Image
  try:
      from openai import OpenAI
      from openai.types.chat import ChatCompletion
+
      # Import OpenAIError for exception handling if needed
  except ImportError:
      OpenAI = None
@@ -32,7 +33,7 @@ except ImportError:
      class LayoutDetector:
          def __init__(self):
              self.logger = logging.getLogger()
-             self.supported_classes = set() # Will be dynamic based on user request
+             self.supported_classes = set()  # Will be dynamic based on user request
 
          def _get_model(self, options):
              raise NotImplementedError
@@ -41,17 +42,20 @@ except ImportError:
              return n.lower().replace("_", "-").replace(" ", "-")
 
          def validate_classes(self, c):
-             pass # Less strict validation needed for LLM
+             pass  # Less strict validation needed for LLM
 
      logging.basicConfig()
 
  logger = logging.getLogger(__name__)
 
+
  # Define Pydantic model for the expected output structure
  # This is used by the openai library's `response_format`
  class DetectedRegion(BaseModel):
      label: str = Field(description="The identified class name.")
-     bbox: List[float] = Field(description="Bounding box coordinates [xmin, ymin, xmax, ymax].", min_items=4, max_items=4)
+     bbox: List[float] = Field(
+         description="Bounding box coordinates [xmin, ymin, xmax, ymax].", min_items=4, max_items=4
+     )
      confidence: float = Field(description="Confidence score [0.0, 1.0].", ge=0.0, le=1.0)
 
 
@@ -63,23 +67,27 @@ class GeminiLayoutDetector(LayoutDetector):
 
      def __init__(self):
          super().__init__()
-         self.supported_classes = set() # Indicate dynamic nature
+         self.supported_classes = set()  # Indicate dynamic nature
 
      def is_available(self) -> bool:
          """Check if openai library is installed and GOOGLE_API_KEY is available."""
          api_key = os.environ.get("GOOGLE_API_KEY")
          if not api_key:
-             logger.warning("GOOGLE_API_KEY environment variable not set. Gemini detector (via OpenAI lib) will not be available.")
+             logger.warning(
+                 "GOOGLE_API_KEY environment variable not set. Gemini detector (via OpenAI lib) will not be available."
+             )
              return False
          if OpenAI is None:
-             logger.warning("openai package not found. Gemini detector (via OpenAI lib) will not be available.")
-             return False
+             logger.warning(
+                 "openai package not found. Gemini detector (via OpenAI lib) will not be available."
+             )
+             return False
          return True
 
      def _get_cache_key(self, options: GeminiLayoutOptions) -> str:
          """Generate cache key based on model name."""
          if not isinstance(options, GeminiLayoutOptions):
-             options = GeminiLayoutOptions() # Use defaults
+             options = GeminiLayoutOptions()  # Use defaults
 
          model_key = options.model_name
          # Prompt is built dynamically, so not part of cache key based on options
@@ -101,9 +109,7 @@ class GeminiLayoutDetector(LayoutDetector):
      def detect(self, image: Image.Image, options: BaseLayoutOptions) -> List[Dict[str, Any]]:
          """Detect layout elements in an image using Gemini via OpenAI library."""
          if not self.is_available():
-             raise RuntimeError(
-                 "OpenAI library not installed or GOOGLE_API_KEY not set."
-             )
+             raise RuntimeError("OpenAI library not installed or GOOGLE_API_KEY not set.")
 
          # Ensure options are the correct type
          if not isinstance(options, GeminiLayoutOptions):
@@ -124,10 +130,7 @@ class GeminiLayoutDetector(LayoutDetector):
          detections = []
          try:
              # --- 1. Initialize OpenAI Client for Gemini ---
-             client = OpenAI(
-                 api_key=api_key,
-                 base_url=self.GEMINI_BASE_URL
-             )
+             client = OpenAI(api_key=api_key, base_url=self.GEMINI_BASE_URL)
 
              # --- 2. Prepare Input for OpenAI API ---
              if not options.classes:
@@ -139,11 +142,11 @@ class GeminiLayoutDetector(LayoutDetector):
              # Convert image to base64
              buffered = io.BytesIO()
              image.save(buffered, format="PNG")
-             img_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
+             img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
              image_url = f"data:image/png;base64,{img_base64}"
 
              # Construct the prompt text
-             class_list_str = ", ".join(f'`{c}`' for c in options.classes)
+             class_list_str = ", ".join(f"`{c}`" for c in options.classes)
              prompt_text = (
                  f"Analyze the provided image of a document page ({width}x{height}). "
                  f"Identify all regions corresponding to the following types: {class_list_str}. "
@@ -165,14 +168,18 @@ class GeminiLayoutDetector(LayoutDetector):
              ]
 
              # --- 3. Call OpenAI API using .parse for structured output ---
-             logger.debug(f"Running Gemini detection via OpenAI lib (Model: {model_name}). Asking for classes: {options.classes}")
+             logger.debug(
+                 f"Running Gemini detection via OpenAI lib (Model: {model_name}). Asking for classes: {options.classes}"
+             )
 
              # Extract relevant generation parameters from extra_args if provided
              # Mapping common names: temperature, top_p, max_tokens
              completion_kwargs = {
-                 "temperature": options.extra_args.get("temperature", 0.2), # Default to low temp
+                 "temperature": options.extra_args.get("temperature", 0.2),  # Default to low temp
                  "top_p": options.extra_args.get("top_p"),
-                 "max_tokens": options.extra_args.get("max_tokens", 4096), # Map from max_output_tokens
+                 "max_tokens": options.extra_args.get(
+                     "max_tokens", 4096
+                 ),  # Map from max_output_tokens
              }
              # Filter out None values
              completion_kwargs = {k: v for k, v in completion_kwargs.items() if v is not None}
@@ -180,13 +187,13 @@ class GeminiLayoutDetector(LayoutDetector):
              completion: ChatCompletion = client.beta.chat.completions.parse(
                  model=model_name,
                  messages=messages,
-                 response_format=List[DetectedRegion], # Pass the Pydantic model list
-                 **completion_kwargs
+                 response_format=List[DetectedRegion],  # Pass the Pydantic model list
+                 **completion_kwargs,
              )
 
              logger.debug(f"Gemini response received via OpenAI lib.")
 
-             # --- 4. Process Parsed Response ---
+             # --- 4. Process Parsed Response ---
              if not completion.choices:
                  logger.error("Gemini response (via OpenAI lib) contained no choices.")
                  return []
@@ -194,16 +201,18 @@ class GeminiLayoutDetector(LayoutDetector):
              # Get the parsed Pydantic objects
              parsed_results = completion.choices[0].message.parsed
              if not parsed_results or not isinstance(parsed_results, list):
-                 logger.error(f"Gemini response (via OpenAI lib) did not contain a valid list of parsed regions. Found: {type(parsed_results)}")
-                 return []
+                 logger.error(
+                     f"Gemini response (via OpenAI lib) did not contain a valid list of parsed regions. Found: {type(parsed_results)}"
+                 )
+                 return []
 
-             # --- 5. Convert to Detections & Filter ---
-             normalized_classes_req = {
-                 self._normalize_class_name(c) for c in options.classes
-             }
-             normalized_classes_excl = {
-                 self._normalize_class_name(c) for c in options.exclude_classes
-             } if options.exclude_classes else set()
+             # --- 5. Convert to Detections & Filter ---
+             normalized_classes_req = {self._normalize_class_name(c) for c in options.classes}
+             normalized_classes_excl = (
+                 {self._normalize_class_name(c) for c in options.exclude_classes}
+                 if options.exclude_classes
+                 else set()
+             )
 
              for item in parsed_results:
                  # The item is already a validated DetectedRegion Pydantic object
@@ -215,33 +224,41 @@ class GeminiLayoutDetector(LayoutDetector):
                  # Coordinates should already be floats, but ensure tuple format
                  xmin, ymin, xmax, ymax = tuple(bbox_raw)
 
-                 # --- Apply Filtering ---
+                 # --- Apply Filtering ---
                  normalized_class = self._normalize_class_name(label)
 
                  # Check against requested classes (Should be guaranteed by schema, but doesn't hurt)
                  if normalized_class not in normalized_classes_req:
-                     logger.warning(f"Gemini (via OpenAI) returned unexpected class '{label}' despite schema. Skipping.")
+                     logger.warning(
+                         f"Gemini (via OpenAI) returned unexpected class '{label}' despite schema. Skipping."
+                     )
                      continue
 
                  # Check against excluded classes
                  if normalized_class in normalized_classes_excl:
-                     logger.debug(f"Skipping excluded class '{label}' (normalized: {normalized_class}).")
+                     logger.debug(
+                         f"Skipping excluded class '{label}' (normalized: {normalized_class})."
+                     )
                      continue
-
+
                  # Check against base confidence threshold from options
                  if confidence_score < options.confidence:
-                     logger.debug(f"Skipping item with confidence {confidence_score:.3f} below threshold {options.confidence}.")
+                     logger.debug(
+                         f"Skipping item with confidence {confidence_score:.3f} below threshold {options.confidence}."
+                     )
                      continue
 
                  # Add detection
-                 detections.append({
-                     "bbox": (xmin, ymin, xmax, ymax),
-                     "class": label, # Use original label from LLM
-                     "confidence": confidence_score,
-                     "normalized_class": normalized_class,
-                     "source": "layout",
-                     "model": "gemini", # Keep model name generic as gemini
-                 })
+                 detections.append(
+                     {
+                         "bbox": (xmin, ymin, xmax, ymax),
+                         "class": label,  # Use original label from LLM
+                         "confidence": confidence_score,
+                         "normalized_class": normalized_class,
+                         "source": "layout",
+                         "model": "gemini",  # Keep model name generic as gemini
+                     }
+                 )
 
              self.logger.info(
                  f"Gemini (via OpenAI lib) processed response. Detected {len(detections)} layout elements matching criteria."
@@ -260,5 +277,4 @@ class GeminiLayoutDetector(LayoutDetector):
 
      def validate_classes(self, classes: List[str]):
          """Validation is less critical as we pass requested classes to the LLM."""
-         pass # Override base validation if needed, but likely not necessary
-
+         pass  # Override base validation if needed, but likely not necessary
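
The reworked `gemini.py` above talks to Gemini through the `openai` client pointed at Google's OpenAI-compatible endpoint, and relies on `client.beta.chat.completions.parse` to get schema-validated `DetectedRegion` objects back instead of hand-parsing JSON. A standalone sketch of that pattern follows; the base URL and model name are illustrative assumptions (the diff only shows them as `self.GEMINI_BASE_URL` and `model_name`), not values confirmed by this release:

```python
import os
from typing import List

from openai import OpenAI
from pydantic import BaseModel, Field


class DetectedRegion(BaseModel):
    label: str = Field(description="The identified class name.")
    bbox: List[float] = Field(
        description="Bounding box coordinates [xmin, ymin, xmax, ymax].", min_items=4, max_items=4
    )
    confidence: float = Field(description="Confidence score [0.0, 1.0].", ge=0.0, le=1.0)


# Endpoint URL and model name below are assumptions for illustration.
client = OpenAI(
    api_key=os.environ["GOOGLE_API_KEY"],
    base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)

completion = client.beta.chat.completions.parse(
    model="gemini-2.0-flash",  # hypothetical; the detector reads options.model_name
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Identify `title` and `table` regions on this page."},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
            ],
        }
    ],
    response_format=List[DetectedRegion],  # same structured-output trick as the diff
    temperature=0.2,
)
regions = completion.choices[0].message.parsed  # validated DetectedRegion objects
```
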
natural_pdf/classification/manager.py ADDED
@@ -0,0 +1,343 @@
+ import logging
+ import time
+ from datetime import datetime
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Tuple
+
+ # Use try-except for robustness if dependencies are missing
+ try:
+     import torch
+     from PIL import Image
+     from transformers import pipeline, AutoTokenizer, AutoModelForZeroShotImageClassification, AutoModelForSequenceClassification
+     _CLASSIFICATION_AVAILABLE = True
+ except ImportError:
+     _CLASSIFICATION_AVAILABLE = False
+     # Define dummy types for type hinting if imports fail
+     Image = type("Image", (), {})
+     pipeline = object
+     AutoTokenizer = object
+     AutoModelForZeroShotImageClassification = object
+     AutoModelForSequenceClassification = object
+     torch = None
+
+ # Import result classes
+ from .results import ClassificationResult, CategoryScore
+ from natural_pdf.utils.tqdm_utils import get_tqdm
+
+ if TYPE_CHECKING:
+     from transformers import Pipeline
+
+
+ logger = logging.getLogger(__name__)
+
+ # Global cache for models/pipelines
+ _PIPELINE_CACHE: Dict[str, "Pipeline"] = {}
+ _TOKENIZER_CACHE: Dict[str, Any] = {}
+ _MODEL_CACHE: Dict[str, Any] = {}
+
+ class ClassificationError(Exception):
+     """Custom exception for classification errors."""
+     pass
+
+
+ class ClassificationManager:
+     """Manages classification models and execution."""
+
+     DEFAULT_TEXT_MODEL = "facebook/bart-large-mnli"
+     DEFAULT_VISION_MODEL = "openai/clip-vit-base-patch16"
+
+     def __init__(
+         self,
+         model_mapping: Optional[Dict[str, str]] = None,
+         default_device: Optional[str] = None,
+     ):
+         """
+         Initialize the ClassificationManager.
+
+         Args:
+             model_mapping: Optional dictionary mapping aliases ('text', 'vision') to model IDs.
+             default_device: Default device ('cpu', 'cuda') if not specified in classify calls.
+         """
+         if not _CLASSIFICATION_AVAILABLE:
+             raise ImportError(
+                 "Classification dependencies missing. "
+                 "Install with: pip install \"natural-pdf[classification]\""
+             )
+
+         self.pipelines: Dict[Tuple[str, str], "Pipeline"] = {}  # Cache: (model_id, device) -> pipeline
+
+         self.device = default_device
+         logger.info(f"ClassificationManager initialized on device: {self.device}")
+
+     def is_available(self) -> bool:
+         """Check if required dependencies are installed."""
+         return _CLASSIFICATION_AVAILABLE
+
+     def _get_pipeline(self, model_id: str, using: str) -> "Pipeline":
+         """Get or create a classification pipeline."""
+         cache_key = f"{model_id}_{using}_{self.device}"
+         if cache_key not in _PIPELINE_CACHE:
+             logger.info(f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'...")
+             start_time = time.time()
+             try:
+                 task = (
+                     "zero-shot-classification"
+                     if using == "text"
+                     else "zero-shot-image-classification"
+                 )
+                 _PIPELINE_CACHE[cache_key] = pipeline(
+                     task,
+                     model=model_id,
+                     device=self.device
+                 )
+                 end_time = time.time()
+                 logger.info(f"Pipeline for '{model_id}' loaded in {end_time - start_time:.2f} seconds.")
+             except Exception as e:
+                 logger.error(f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}", exc_info=True)
+                 raise ClassificationError(f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task.") from e
+         return _PIPELINE_CACHE[cache_key]
+
+     def infer_using(self, model_id: str, using: Optional[str] = None) -> str:
+         """Infers processing mode ('text' or 'vision') if not provided."""
+         if using in ["text", "vision"]:
+             return using
+
+         # Simple inference based on common model names
+         normalized_model_id = model_id.lower()
+         if "clip" in normalized_model_id or "vit" in normalized_model_id or "siglip" in normalized_model_id:
+             logger.debug(f"Inferred using='vision' for model '{model_id}'")
+             return "vision"
+         if "bart" in normalized_model_id or "bert" in normalized_model_id or "mnli" in normalized_model_id or "xnli" in normalized_model_id or "deberta" in normalized_model_id:
+             logger.debug(f"Inferred using='text' for model '{model_id}'")
+             return "text"
+
+         # Fallback or raise error? Let's try loading text first, then vision.
+         logger.warning(f"Could not reliably infer mode for '{model_id}'. Trying text, then vision pipeline loading.")
+         try:
+             self._get_pipeline(model_id, "text")
+             logger.info(f"Successfully loaded '{model_id}' as a text model.")
+             return "text"
+         except Exception:
+             logger.warning(f"Failed to load '{model_id}' as text model. Trying vision.")
+             try:
+                 self._get_pipeline(model_id, "vision")
+                 logger.info(f"Successfully loaded '{model_id}' as a vision model.")
+                 return "vision"
+             except Exception as e_vision:
+                 logger.error(f"Failed to load '{model_id}' as either text or vision model.", exc_info=True)
+                 raise ClassificationError(f"Cannot determine mode for model '{model_id}'. Please specify `using='text'` or `using='vision'`. Error: {e_vision}")
+
+     def classify_item(
+         self,
+         item_content: Union[str, Image.Image],
+         categories: List[str],
+         model_id: Optional[str] = None,
+         using: Optional[str] = None,
+         min_confidence: float = 0.0,
+         multi_label: bool = False,
+         **kwargs
+     ) -> ClassificationResult:  # Return ClassificationResult
+         """Classifies a single item (text or image)."""
+
+         # Determine model and engine type
+         effective_using = using
+         if model_id is None:
+             # Try inferring based on content type
+             if isinstance(item_content, str):
+                 effective_using = "text"
+                 model_id = self.DEFAULT_TEXT_MODEL
+             elif isinstance(item_content, Image.Image):
+                 effective_using = "vision"
+                 model_id = self.DEFAULT_VISION_MODEL
+             else:
+                 raise TypeError(f"Unsupported item_content type: {type(item_content)}")
+         else:
+             # Infer engine type if not given
+             effective_using = self.infer_using(model_id, using)
+             # Set default model if needed (though should usually be provided if engine known)
+             if model_id is None:
+                 model_id = self.DEFAULT_TEXT_MODEL if effective_using == "text" else self.DEFAULT_VISION_MODEL
+
+         if not categories:
+             raise ValueError("Categories list cannot be empty.")
+
+         pipeline_instance = self._get_pipeline(model_id, effective_using)
+         timestamp = datetime.now()
+         parameters = {  # Store parameters used for this run
+             'categories': categories,
+             'model_id': model_id,
+             'using': effective_using,
+             'min_confidence': min_confidence,
+             'multi_label': multi_label,
+             **kwargs
+         }
+
+         logger.debug(f"Classifying content (type: {type(item_content).__name__}) with model '{model_id}'")
+         try:
+             # Handle potential kwargs for specific pipelines if needed
+             # The zero-shot pipelines expect `candidate_labels`
+             result_raw = pipeline_instance(item_content, candidate_labels=categories, multi_label=multi_label, **kwargs)
+             logger.debug(f"Raw pipeline result: {result_raw}")
+
+             # --- Process raw result into ClassificationResult --- #
+             scores_list: List[CategoryScore] = []
+
+             # Handle text pipeline format (dict with 'labels' and 'scores')
+             if isinstance(result_raw, dict) and 'labels' in result_raw and 'scores' in result_raw:
+                 for label, score_val in zip(result_raw['labels'], result_raw['scores']):
+                     if score_val >= min_confidence:
+                         try:
+                             scores_list.append(CategoryScore(label=label, confidence=score_val))
+                         except (ValueError, TypeError) as score_err:
+                             logger.warning(f"Skipping invalid score from text pipeline: label='{label}', score={score_val}. Error: {score_err}")
+             # Handle vision pipeline format (list of dicts with 'label' and 'score')
+             elif isinstance(result_raw, list) and all(isinstance(item, dict) and 'label' in item and 'score' in item for item in result_raw):
+                 for item in result_raw:
+                     score_val = item['score']
+                     label = item['label']
+                     if score_val >= min_confidence:
+                         try:
+                             scores_list.append(CategoryScore(label=label, confidence=score_val))
+                         except (ValueError, TypeError) as score_err:
+                             logger.warning(f"Skipping invalid score from vision pipeline: label='{label}', score={score_val}. Error: {score_err}")
+             else:
+                 logger.warning(f"Unexpected raw result format from pipeline for model '{model_id}': {type(result_raw)}. Cannot extract scores.")
+                 # Return empty result?
+                 # scores_list = []
+
+             return ClassificationResult(
+                 model_id=model_id,
+                 using=effective_using,
+                 timestamp=timestamp,
+                 parameters=parameters,
+                 scores=scores_list
+             )
+             # --- End Processing --- #
+
+         except Exception as e:
+             logger.error(f"Classification failed for model '{model_id}': {e}", exc_info=True)
+             # Return an empty result object on failure?
+             # return ClassificationResult(model_id=model_id, engine_type=engine_type, timestamp=timestamp, parameters=parameters, scores=[])
+             raise ClassificationError(f"Classification failed using model '{model_id}'. Error: {e}") from e
+
+     def classify_batch(
+         self,
+         item_contents: List[Union[str, Image.Image]],
+         categories: List[str],
+         model_id: Optional[str] = None,
+         using: Optional[str] = None,
+         min_confidence: float = 0.0,
+         multi_label: bool = False,
+         batch_size: int = 8,
+         progress_bar: bool = True,
+         **kwargs
+     ) -> List[ClassificationResult]:  # Return list of ClassificationResult
+         """Classifies a batch of items (text or image) using the pipeline's batching."""
+         if not item_contents:
+             return []
+
+         # Determine model and engine type (assuming uniform type in batch)
+         first_item = item_contents[0]
+         effective_using = using
+         if model_id is None:
+             if isinstance(first_item, str):
+                 effective_using = "text"
+                 model_id = self.DEFAULT_TEXT_MODEL
+             elif isinstance(first_item, Image.Image):
+                 effective_using = "vision"
+                 model_id = self.DEFAULT_VISION_MODEL
+             else:
+                 raise TypeError(f"Unsupported item_content type in batch: {type(first_item)}")
+         else:
+             effective_using = self.infer_using(model_id, using)
+             if model_id is None:
+                 model_id = self.DEFAULT_TEXT_MODEL if effective_using == "text" else self.DEFAULT_VISION_MODEL
+
+         if not categories:
+             raise ValueError("Categories list cannot be empty.")
+
+         pipeline_instance = self._get_pipeline(model_id, effective_using)
+         timestamp = datetime.now()  # Single timestamp for the batch run
+         parameters = {  # Parameters for the whole batch
+             'categories': categories,
+             'model_id': model_id,
+             'using': effective_using,
+             'min_confidence': min_confidence,
+             'multi_label': multi_label,
+             'batch_size': batch_size,
+             **kwargs
+         }
+
+         logger.info(f"Classifying batch of {len(item_contents)} items with model '{model_id}' (batch size: {batch_size})")
+         batch_results_list: List[ClassificationResult] = []
+
+         try:
+             # Use pipeline directly for batching
+             results_iterator = pipeline_instance(
+                 item_contents,
+                 candidate_labels=categories,
+                 multi_label=multi_label,
+                 batch_size=batch_size,
+                 **kwargs
+             )
+
+             # Wrap with tqdm for progress if requested
+             total_items = len(item_contents)
+             if progress_bar:
+                 # Get the appropriate tqdm class
+                 tqdm_class = get_tqdm()
+                 results_iterator = tqdm_class(
+                     results_iterator,
+                     total=total_items,
+                     desc=f"Classifying batch ({model_id})",
+                     leave=False  # Don't leave progress bar hanging
+                 )
+
+             for raw_result in results_iterator:
+                 # --- Process each raw result (which corresponds to ONE input item) --- #
+                 scores_list: List[CategoryScore] = []
+                 try:
+                     # Check for text format (dict with 'labels' and 'scores')
+                     if isinstance(raw_result, dict) and 'labels' in raw_result and 'scores' in raw_result:
+                         for label, score_val in zip(raw_result['labels'], raw_result['scores']):
+                             if score_val >= min_confidence:
+                                 try:
+                                     scores_list.append(CategoryScore(label=label, confidence=score_val))
+                                 except (ValueError, TypeError) as score_err:
+                                     logger.warning(f"Skipping invalid score from text pipeline batch: label='{label}', score={score_val}. Error: {score_err}")
+                     # Check for vision format (list of dicts with 'label' and 'score')
+                     elif isinstance(raw_result, list):
+                         for item in raw_result:
+                             try:
+                                 score_val = item['score']
+                                 label = item['label']
+                                 if score_val >= min_confidence:
+                                     scores_list.append(CategoryScore(label=label, confidence=score_val))
+                             except (KeyError, ValueError, TypeError) as item_err:
+                                 logger.warning(f"Skipping invalid item in vision result list from batch: {item}. Error: {item_err}")
+                     else:
+                         logger.warning(f"Unexpected raw result format in batch item from model '{model_id}': {type(raw_result)}. Cannot extract scores.")
+
+                 except Exception as proc_err:
+                     logger.error(f"Error processing result item in batch: {proc_err}", exc_info=True)
+                     # scores_list remains empty for this item
+
+                 # Append result object for this item
+                 batch_results_list.append(ClassificationResult(
+                     model_id=model_id,
+                     using=effective_using,
+                     timestamp=timestamp,  # Use same timestamp for batch
+                     parameters=parameters,  # Use same params for batch
+                     scores=scores_list
+                 ))
+                 # --- End Processing --- #
+
+             if len(batch_results_list) != total_items:
+                 logger.warning(f"Batch classification returned {len(batch_results_list)} results, but expected {total_items}. Results might be incomplete or misaligned.")
+
+             return batch_results_list
+
+         except Exception as e:
+             logger.error(f"Batch classification failed for model '{model_id}': {e}", exc_info=True)
+             # Return list of empty results?
+             # return [ClassificationResult(model_id=model_id, s=engine_type, timestamp=timestamp, parameters=parameters, scores=[]) for _ in item_contents]
+             raise ClassificationError(f"Batch classification failed using model '{model_id}'. Error: {e}") from e
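
The new `ClassificationManager` caches one zero-shot pipeline per model/mode/device combination in `_PIPELINE_CACHE`, infers text vs. vision mode from the model name (falling back to trying each pipeline in turn), and normalizes both pipeline output shapes into `ClassificationResult` objects. A minimal usage sketch, assuming the `natural-pdf[classification]` extra is installed; the sample text and category names are invented for illustration:

```python
from natural_pdf.classification.manager import ClassificationManager

manager = ClassificationManager(default_device="cpu")

# A plain string routes to the default text model
# ("facebook/bart-large-mnli") via the zero-shot-classification task.
result = manager.classify_item(
    "Total amount due: $1,234.56",  # invented sample text
    categories=["invoice", "letter", "report"],
    min_confidence=0.1,
)
for score in result.scores:
    print(score.label, score.confidence)

# Batched: one ClassificationResult per input item, all sharing the
# batch's timestamp and parameter dict; a tqdm bar shows progress.
results = manager.classify_batch(
    ["First sample text...", "Second sample text..."],
    categories=["invoice", "letter", "report"],
    batch_size=8,
)
```
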