natural-pdf 0.1.8__py3-none-any.whl → 0.1.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. natural_pdf/__init__.py +1 -0
  2. natural_pdf/analyzers/layout/base.py +1 -5
  3. natural_pdf/analyzers/layout/gemini.py +61 -51
  4. natural_pdf/analyzers/layout/layout_analyzer.py +40 -11
  5. natural_pdf/analyzers/layout/layout_manager.py +26 -84
  6. natural_pdf/analyzers/layout/layout_options.py +7 -0
  7. natural_pdf/analyzers/layout/pdfplumber_table_finder.py +142 -0
  8. natural_pdf/analyzers/layout/surya.py +46 -123
  9. natural_pdf/analyzers/layout/tatr.py +51 -4
  10. natural_pdf/analyzers/text_structure.py +3 -5
  11. natural_pdf/analyzers/utils.py +3 -3
  12. natural_pdf/classification/manager.py +241 -158
  13. natural_pdf/classification/mixin.py +52 -38
  14. natural_pdf/classification/results.py +71 -45
  15. natural_pdf/collections/mixins.py +85 -20
  16. natural_pdf/collections/pdf_collection.py +245 -100
  17. natural_pdf/core/element_manager.py +30 -14
  18. natural_pdf/core/highlighting_service.py +13 -22
  19. natural_pdf/core/page.py +423 -101
  20. natural_pdf/core/pdf.py +694 -195
  21. natural_pdf/elements/base.py +134 -40
  22. natural_pdf/elements/collections.py +610 -134
  23. natural_pdf/elements/region.py +659 -90
  24. natural_pdf/elements/text.py +1 -1
  25. natural_pdf/export/mixin.py +137 -0
  26. natural_pdf/exporters/base.py +3 -3
  27. natural_pdf/exporters/paddleocr.py +4 -3
  28. natural_pdf/extraction/manager.py +50 -49
  29. natural_pdf/extraction/mixin.py +90 -57
  30. natural_pdf/extraction/result.py +9 -23
  31. natural_pdf/ocr/__init__.py +5 -5
  32. natural_pdf/ocr/engine_doctr.py +346 -0
  33. natural_pdf/ocr/ocr_factory.py +24 -4
  34. natural_pdf/ocr/ocr_manager.py +61 -25
  35. natural_pdf/ocr/ocr_options.py +70 -10
  36. natural_pdf/ocr/utils.py +6 -4
  37. natural_pdf/search/__init__.py +20 -34
  38. natural_pdf/search/haystack_search_service.py +309 -265
  39. natural_pdf/search/haystack_utils.py +99 -75
  40. natural_pdf/search/search_service_protocol.py +11 -12
  41. natural_pdf/selectors/parser.py +219 -143
  42. natural_pdf/utils/debug.py +3 -3
  43. natural_pdf/utils/identifiers.py +1 -1
  44. natural_pdf/utils/locks.py +1 -1
  45. natural_pdf/utils/packaging.py +8 -6
  46. natural_pdf/utils/text_extraction.py +24 -16
  47. natural_pdf/utils/tqdm_utils.py +18 -10
  48. natural_pdf/utils/visualization.py +18 -0
  49. natural_pdf/widgets/viewer.py +4 -25
  50. {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.10.dist-info}/METADATA +12 -3
  51. natural_pdf-0.1.10.dist-info/RECORD +80 -0
  52. {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.10.dist-info}/WHEEL +1 -1
  53. {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.10.dist-info}/top_level.txt +0 -2
  54. docs/api/index.md +0 -386
  55. docs/assets/favicon.png +0 -3
  56. docs/assets/favicon.svg +0 -3
  57. docs/assets/javascripts/custom.js +0 -17
  58. docs/assets/logo.svg +0 -3
  59. docs/assets/sample-screen.png +0 -0
  60. docs/assets/social-preview.png +0 -17
  61. docs/assets/social-preview.svg +0 -17
  62. docs/assets/stylesheets/custom.css +0 -65
  63. docs/categorizing-documents/index.md +0 -168
  64. docs/data-extraction/index.md +0 -87
  65. docs/document-qa/index.ipynb +0 -435
  66. docs/document-qa/index.md +0 -79
  67. docs/element-selection/index.ipynb +0 -969
  68. docs/element-selection/index.md +0 -249
  69. docs/finetuning/index.md +0 -176
  70. docs/index.md +0 -189
  71. docs/installation/index.md +0 -69
  72. docs/interactive-widget/index.ipynb +0 -962
  73. docs/interactive-widget/index.md +0 -12
  74. docs/layout-analysis/index.ipynb +0 -818
  75. docs/layout-analysis/index.md +0 -185
  76. docs/ocr/index.md +0 -256
  77. docs/pdf-navigation/index.ipynb +0 -314
  78. docs/pdf-navigation/index.md +0 -97
  79. docs/regions/index.ipynb +0 -816
  80. docs/regions/index.md +0 -294
  81. docs/tables/index.ipynb +0 -658
  82. docs/tables/index.md +0 -144
  83. docs/text-analysis/index.ipynb +0 -370
  84. docs/text-analysis/index.md +0 -105
  85. docs/text-extraction/index.ipynb +0 -1478
  86. docs/text-extraction/index.md +0 -292
  87. docs/tutorials/01-loading-and-extraction.ipynb +0 -1873
  88. docs/tutorials/01-loading-and-extraction.md +0 -95
  89. docs/tutorials/02-finding-elements.ipynb +0 -417
  90. docs/tutorials/02-finding-elements.md +0 -149
  91. docs/tutorials/03-extracting-blocks.ipynb +0 -152
  92. docs/tutorials/03-extracting-blocks.md +0 -48
  93. docs/tutorials/04-table-extraction.ipynb +0 -119
  94. docs/tutorials/04-table-extraction.md +0 -50
  95. docs/tutorials/05-excluding-content.ipynb +0 -275
  96. docs/tutorials/05-excluding-content.md +0 -109
  97. docs/tutorials/06-document-qa.ipynb +0 -337
  98. docs/tutorials/06-document-qa.md +0 -91
  99. docs/tutorials/07-layout-analysis.ipynb +0 -293
  100. docs/tutorials/07-layout-analysis.md +0 -66
  101. docs/tutorials/07-working-with-regions.ipynb +0 -414
  102. docs/tutorials/07-working-with-regions.md +0 -151
  103. docs/tutorials/08-spatial-navigation.ipynb +0 -513
  104. docs/tutorials/08-spatial-navigation.md +0 -190
  105. docs/tutorials/09-section-extraction.ipynb +0 -2439
  106. docs/tutorials/09-section-extraction.md +0 -256
  107. docs/tutorials/10-form-field-extraction.ipynb +0 -517
  108. docs/tutorials/10-form-field-extraction.md +0 -201
  109. docs/tutorials/11-enhanced-table-processing.ipynb +0 -59
  110. docs/tutorials/11-enhanced-table-processing.md +0 -9
  111. docs/tutorials/12-ocr-integration.ipynb +0 -3712
  112. docs/tutorials/12-ocr-integration.md +0 -137
  113. docs/tutorials/13-semantic-search.ipynb +0 -1718
  114. docs/tutorials/13-semantic-search.md +0 -77
  115. docs/visual-debugging/index.ipynb +0 -2970
  116. docs/visual-debugging/index.md +0 -157
  117. docs/visual-debugging/region.png +0 -0
  118. natural_pdf/templates/finetune/fine_tune_paddleocr.md +0 -420
  119. natural_pdf/templates/spa/css/style.css +0 -334
  120. natural_pdf/templates/spa/index.html +0 -31
  121. natural_pdf/templates/spa/js/app.js +0 -472
  122. natural_pdf/templates/spa/words.txt +0 -235976
  123. natural_pdf/widgets/frontend/viewer.js +0 -88
  124. natural_pdf-0.1.8.dist-info/RECORD +0 -156
  125. notebooks/Examples.ipynb +0 -1293
  126. pdfs/.gitkeep +0 -0
  127. pdfs/01-practice.pdf +0 -543
  128. pdfs/0500000US42001.pdf +0 -0
  129. pdfs/0500000US42007.pdf +0 -0
  130. pdfs/2014 Statistics.pdf +0 -0
  131. pdfs/2019 Statistics.pdf +0 -0
  132. pdfs/Atlanta_Public_Schools_GA_sample.pdf +0 -0
  133. pdfs/needs-ocr.pdf +0 -0
  134. {natural_pdf-0.1.8.dist-info → natural_pdf-0.1.10.dist-info}/licenses/LICENSE +0 -0
natural_pdf/classification/manager.py

@@ -1,28 +1,35 @@
 import logging
 import time
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+from PIL import Image
 
 # Use try-except for robustness if dependencies are missing
 try:
     import torch
-    from PIL import Image
-    from transformers import pipeline, AutoTokenizer, AutoModelForZeroShotImageClassification, AutoModelForSequenceClassification
+    from transformers import (
+        AutoModelForSequenceClassification,
+        AutoModelForZeroShotImageClassification,
+        AutoTokenizer,
+        pipeline,
+    )
+
     _CLASSIFICATION_AVAILABLE = True
 except ImportError:
     _CLASSIFICATION_AVAILABLE = False
     # Define dummy types for type hinting if imports fail
-    Image = type("Image", (), {})
     pipeline = object
     AutoTokenizer = object
     AutoModelForZeroShotImageClassification = object
     AutoModelForSequenceClassification = object
     torch = None
 
-# Import result classes
-from .results import ClassificationResult, CategoryScore
 from natural_pdf.utils.tqdm_utils import get_tqdm
 
+# Import result classes
+from .results import CategoryScore, ClassificationResult
+
 if TYPE_CHECKING:
     from transformers import Pipeline
 
@@ -34,8 +41,10 @@ _PIPELINE_CACHE: Dict[str, "Pipeline"] = {}
 _TOKENIZER_CACHE: Dict[str, Any] = {}
 _MODEL_CACHE: Dict[str, Any] = {}
 
+
 class ClassificationError(Exception):
     """Custom exception for classification errors."""
+
     pass
 
 
@@ -60,10 +69,12 @@ class ClassificationManager:
         if not _CLASSIFICATION_AVAILABLE:
             raise ImportError(
                 "Classification dependencies missing. "
-                "Install with: pip install \"natural-pdf[classification]\""
+                'Install with: pip install "natural-pdf[classification]"'
             )
 
-        self.pipelines: Dict[Tuple[str, str], "Pipeline"] = {} # Cache: (model_id, device) -> pipeline
+        self.pipelines: Dict[Tuple[str, str], "Pipeline"] = (
+            {}
+        )  # Cache: (model_id, device) -> pipeline
 
         self.device = default_device
         logger.info(f"ClassificationManager initialized on device: {self.device}")
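The hunk above keeps the guarded-import behaviour: constructing `ClassificationManager` raises `ImportError` when the optional classification dependencies are missing. A minimal sketch of how calling code might handle that case; the ImportError message is taken from the diff, while the no-argument constructor call is an assumption (the full signature is not shown in this hunk). The diff of manager.py continues below.

```python
from natural_pdf.classification.manager import ClassificationManager

try:
    manager = ClassificationManager()  # assumption: defaults are accepted here
except ImportError as err:
    # Raised when torch/transformers are absent; the message in the diff points to:
    #   pip install "natural-pdf[classification]"
    raise SystemExit(f"Classification support is unavailable: {err}")
```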
@@ -76,7 +87,9 @@ class ClassificationManager:
         """Get or create a classification pipeline."""
         cache_key = f"{model_id}_{using}_{self.device}"
         if cache_key not in _PIPELINE_CACHE:
-            logger.info(f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'...")
+            logger.info(
+                f"Loading {using} classification pipeline for model '{model_id}' on device '{self.device}'..."
+            )
             start_time = time.time()
             try:
                 task = (
@@ -84,16 +97,19 @@ class ClassificationManager:
                     if using == "text"
                     else "zero-shot-image-classification"
                 )
-                _PIPELINE_CACHE[cache_key] = pipeline(
-                    task,
-                    model=model_id,
-                    device=self.device
-                )
+                _PIPELINE_CACHE[cache_key] = pipeline(task, model=model_id, device=self.device)
                 end_time = time.time()
-                logger.info(f"Pipeline for '{model_id}' loaded in {end_time - start_time:.2f} seconds.")
+                logger.info(
+                    f"Pipeline for '{model_id}' loaded in {end_time - start_time:.2f} seconds."
+                )
             except Exception as e:
-                logger.error(f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}", exc_info=True)
-                raise ClassificationError(f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task.") from e
+                logger.error(
+                    f"Failed to load pipeline for model '{model_id}' (using: {using}): {e}",
+                    exc_info=True,
+                )
+                raise ClassificationError(
+                    f"Failed to load pipeline for model '{model_id}'. Ensure the model ID is correct and supports the {task} task."
+                ) from e
         return _PIPELINE_CACHE[cache_key]
 
     def infer_using(self, model_id: str, using: Optional[str] = None) -> str:
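For reference, the caching shown in `_get_pipeline` above comes down to a module-level dictionary keyed on model, mode, and device, so repeated classification calls reuse one loaded zero-shot pipeline. A standalone sketch of the same idea under stated assumptions: the `factory` callable stands in for the `transformers.pipeline(...)` construction and is not part of the package's API. The diff of manager.py continues below.

```python
from typing import Any, Callable, Dict

_PIPELINE_CACHE: Dict[str, Any] = {}


def get_cached_pipeline(model_id: str, using: str, device: str, factory: Callable[[], Any]) -> Any:
    # Same cache-key shape as in the diff: one entry per (model, mode, device) combination.
    key = f"{model_id}_{using}_{device}"
    if key not in _PIPELINE_CACHE:
        _PIPELINE_CACHE[key] = factory()  # e.g. a zero-shot classification pipeline
    return _PIPELINE_CACHE[key]
```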
@@ -103,241 +119,308 @@ class ClassificationManager:
 
         # Simple inference based on common model names
         normalized_model_id = model_id.lower()
-        if "clip" in normalized_model_id or "vit" in normalized_model_id or "siglip" in normalized_model_id:
-            logger.debug(f"Inferred using='vision' for model '{model_id}'")
-            return "vision"
-        if "bart" in normalized_model_id or "bert" in normalized_model_id or "mnli" in normalized_model_id or "xnli" in normalized_model_id or "deberta" in normalized_model_id:
-            logger.debug(f"Inferred using='text' for model '{model_id}'")
-            return "text"
+        if (
+            "clip" in normalized_model_id
+            or "vit" in normalized_model_id
+            or "siglip" in normalized_model_id
+        ):
+            logger.debug(f"Inferred using='vision' for model '{model_id}'")
+            return "vision"
+        if (
+            "bart" in normalized_model_id
+            or "bert" in normalized_model_id
+            or "mnli" in normalized_model_id
+            or "xnli" in normalized_model_id
+            or "deberta" in normalized_model_id
+        ):
+            logger.debug(f"Inferred using='text' for model '{model_id}'")
+            return "text"
 
         # Fallback or raise error? Let's try loading text first, then vision.
-        logger.warning(f"Could not reliably infer mode for '{model_id}'. Trying text, then vision pipeline loading.")
+        logger.warning(
+            f"Could not reliably infer mode for '{model_id}'. Trying text, then vision pipeline loading."
+        )
         try:
             self._get_pipeline(model_id, "text")
             logger.info(f"Successfully loaded '{model_id}' as a text model.")
             return "text"
         except Exception:
-            logger.warning(f"Failed to load '{model_id}' as text model. Trying vision.")
-            try:
-                self._get_pipeline(model_id, "vision")
-                logger.info(f"Successfully loaded '{model_id}' as a vision model.")
-                return "vision"
-            except Exception as e_vision:
-                logger.error(f"Failed to load '{model_id}' as either text or vision model.", exc_info=True)
-                raise ClassificationError(f"Cannot determine mode for model '{model_id}'. Please specify `using='text'` or `using='vision'`. Error: {e_vision}")
+            logger.warning(f"Failed to load '{model_id}' as text model. Trying vision.")
+            try:
+                self._get_pipeline(model_id, "vision")
+                logger.info(f"Successfully loaded '{model_id}' as a vision model.")
+                return "vision"
+            except Exception as e_vision:
+                logger.error(
+                    f"Failed to load '{model_id}' as either text or vision model.", exc_info=True
+                )
+                raise ClassificationError(
+                    f"Cannot determine mode for model '{model_id}'. Please specify `using='text'` or `using='vision'`. Error: {e_vision}"
+                )
 
     def classify_item(
         self,
         item_content: Union[str, Image.Image],
-        categories: List[str],
+        labels: List[str],
         model_id: Optional[str] = None,
         using: Optional[str] = None,
         min_confidence: float = 0.0,
         multi_label: bool = False,
-        **kwargs
-    ) -> ClassificationResult: # Return ClassificationResult
+        **kwargs,
+    ) -> ClassificationResult:  # Return ClassificationResult
         """Classifies a single item (text or image)."""
 
         # Determine model and engine type
         effective_using = using
         if model_id is None:
-            # Try inferring based on content type
-            if isinstance(item_content, str):
-                effective_using = "text"
-                model_id = self.DEFAULT_TEXT_MODEL
-            elif isinstance(item_content, Image.Image):
-                effective_using = "vision"
-                model_id = self.DEFAULT_VISION_MODEL
-            else:
-                raise TypeError(f"Unsupported item_content type: {type(item_content)}")
+            # Try inferring based on content type
+            if isinstance(item_content, str):
+                effective_using = "text"
+                model_id = self.DEFAULT_TEXT_MODEL
+            elif isinstance(item_content, Image.Image):
+                effective_using = "vision"
+                model_id = self.DEFAULT_VISION_MODEL
+            else:
+                raise TypeError(f"Unsupported item_content type: {type(item_content)}")
         else:
-            # Infer engine type if not given
-            effective_using = self.infer_using(model_id, using)
-            # Set default model if needed (though should usually be provided if engine known)
-            if model_id is None:
-                model_id = self.DEFAULT_TEXT_MODEL if effective_using == "text" else self.DEFAULT_VISION_MODEL
+            # Infer engine type if not given
+            effective_using = self.infer_using(model_id, using)
+            # Set default model if needed (though should usually be provided if engine known)
+            if model_id is None:
+                model_id = (
+                    self.DEFAULT_TEXT_MODEL
+                    if effective_using == "text"
+                    else self.DEFAULT_VISION_MODEL
+                )
 
-        if not categories:
-            raise ValueError("Categories list cannot be empty.")
+        if not labels:
+            raise ValueError("Labels list cannot be empty.")
 
         pipeline_instance = self._get_pipeline(model_id, effective_using)
         timestamp = datetime.now()
-        parameters = { # Store parameters used for this run
-            'categories': categories,
-            'model_id': model_id,
-            'using': effective_using,
-            'min_confidence': min_confidence,
-            'multi_label': multi_label,
-            **kwargs
+        parameters = {  # Store parameters used for this run
+            "labels": labels,
+            "model_id": model_id,
+            "using": effective_using,
+            "min_confidence": min_confidence,
+            "multi_label": multi_label,
+            **kwargs,
         }
 
-        logger.debug(f"Classifying content (type: {type(item_content).__name__}) with model '{model_id}'")
+        logger.debug(
+            f"Classifying content (type: {type(item_content).__name__}) with model '{model_id}'"
+        )
         try:
             # Handle potential kwargs for specific pipelines if needed
             # The zero-shot pipelines expect `candidate_labels`
-            result_raw = pipeline_instance(item_content, candidate_labels=categories, multi_label=multi_label, **kwargs)
+            result_raw = pipeline_instance(
+                item_content, candidate_labels=labels, multi_label=multi_label, **kwargs
+            )
             logger.debug(f"Raw pipeline result: {result_raw}")
 
-            # --- Process raw result into ClassificationResult --- #
+            # --- Process raw result into ClassificationResult --- #
             scores_list: List[CategoryScore] = []
-
+
             # Handle text pipeline format (dict with 'labels' and 'scores')
-            if isinstance(result_raw, dict) and 'labels' in result_raw and 'scores' in result_raw:
-                for label, score_val in zip(result_raw['labels'], result_raw['scores']):
-                    if score_val >= min_confidence:
-                        try:
-                            scores_list.append(CategoryScore(label=label, confidence=score_val))
-                        except (ValueError, TypeError) as score_err:
-                            logger.warning(f"Skipping invalid score from text pipeline: label='{label}', score={score_val}. Error: {score_err}")
+            if isinstance(result_raw, dict) and "labels" in result_raw and "scores" in result_raw:
+                for label, score_val in zip(result_raw["labels"], result_raw["scores"]):
+                    if score_val >= min_confidence:
+                        try:
+                            scores_list.append(CategoryScore(label, score_val))
+                        except (ValueError, TypeError) as score_err:
+                            logger.warning(
+                                f"Skipping invalid score from text pipeline: label='{label}', score={score_val}. Error: {score_err}"
+                            )
             # Handle vision pipeline format (list of dicts with 'label' and 'score')
-            elif isinstance(result_raw, list) and all(isinstance(item, dict) and 'label' in item and 'score' in item for item in result_raw):
-                for item in result_raw:
-                    score_val = item['score']
-                    label = item['label']
-                    if score_val >= min_confidence:
-                        try:
-                            scores_list.append(CategoryScore(label=label, confidence=score_val))
-                        except (ValueError, TypeError) as score_err:
-                            logger.warning(f"Skipping invalid score from vision pipeline: label='{label}', score={score_val}. Error: {score_err}")
+            elif isinstance(result_raw, list) and all(
+                isinstance(item, dict) and "label" in item and "score" in item
+                for item in result_raw
+            ):
+                for item in result_raw:
+                    score_val = item["score"]
+                    label = item["label"]
+                    if score_val >= min_confidence:
+                        try:
+                            scores_list.append(CategoryScore(label, score_val))
+                        except (ValueError, TypeError) as score_err:
+                            logger.warning(
+                                f"Skipping invalid score from vision pipeline: label='{label}', score={score_val}. Error: {score_err}"
+                            )
             else:
-                logger.warning(f"Unexpected raw result format from pipeline for model '{model_id}': {type(result_raw)}. Cannot extract scores.")
-                # Return empty result?
-                # scores_list = []
+                logger.warning(
+                    f"Unexpected raw result format from pipeline for model '{model_id}': {type(result_raw)}. Cannot extract scores."
+                )
+                # Return empty result?
+                # scores_list = []
 
-            return ClassificationResult(
+            # ClassificationResult now calculates top score/category internally
+            result_obj = ClassificationResult(
+                scores=scores_list,  # Pass the filtered list
                 model_id=model_id,
                 using=effective_using,
-                timestamp=timestamp,
                 parameters=parameters,
-                scores=scores_list
+                timestamp=timestamp,
             )
+            return result_obj
             # --- End Processing --- #
 
         except Exception as e:
-            logger.error(f"Classification failed for model '{model_id}': {e}", exc_info=True)
-            # Return an empty result object on failure?
-            # return ClassificationResult(model_id=model_id, engine_type=engine_type, timestamp=timestamp, parameters=parameters, scores=[])
-            raise ClassificationError(f"Classification failed using model '{model_id}'. Error: {e}") from e
+            logger.error(f"Classification failed for model '{model_id}': {e}", exc_info=True)
+            # Return an empty result object on failure?
+            # return ClassificationResult(model_id=model_id, engine_type=engine_type, timestamp=timestamp, parameters=parameters, scores=[])
+            raise ClassificationError(
+                f"Classification failed using model '{model_id}'. Error: {e}"
+            ) from e
 
     def classify_batch(
         self,
         item_contents: List[Union[str, Image.Image]],
-        categories: List[str],
+        labels: List[str],
         model_id: Optional[str] = None,
         using: Optional[str] = None,
         min_confidence: float = 0.0,
         multi_label: bool = False,
         batch_size: int = 8,
         progress_bar: bool = True,
-        **kwargs
-    ) -> List[ClassificationResult]: # Return list of ClassificationResult
+        **kwargs,
+    ) -> List[ClassificationResult]:  # Return list of ClassificationResult
         """Classifies a batch of items (text or image) using the pipeline's batching."""
         if not item_contents:
-            return []
+            return []
 
         # Determine model and engine type (assuming uniform type in batch)
         first_item = item_contents[0]
         effective_using = using
         if model_id is None:
-            if isinstance(first_item, str):
-                effective_using = "text"
-                model_id = self.DEFAULT_TEXT_MODEL
-            elif isinstance(first_item, Image.Image):
-                effective_using = "vision"
-                model_id = self.DEFAULT_VISION_MODEL
-            else:
-                raise TypeError(f"Unsupported item_content type in batch: {type(first_item)}")
+            if isinstance(first_item, str):
+                effective_using = "text"
+                model_id = self.DEFAULT_TEXT_MODEL
+            elif isinstance(first_item, Image.Image):
+                effective_using = "vision"
+                model_id = self.DEFAULT_VISION_MODEL
+            else:
+                raise TypeError(f"Unsupported item_content type in batch: {type(first_item)}")
         else:
-            effective_using = self.infer_using(model_id, using)
-            if model_id is None:
-                model_id = self.DEFAULT_TEXT_MODEL if effective_using == "text" else self.DEFAULT_VISION_MODEL
+            effective_using = self.infer_using(model_id, using)
+            if model_id is None:
+                model_id = (
+                    self.DEFAULT_TEXT_MODEL
+                    if effective_using == "text"
+                    else self.DEFAULT_VISION_MODEL
+                )
 
-        if not categories:
-            raise ValueError("Categories list cannot be empty.")
+        if not labels:
+            raise ValueError("Labels list cannot be empty.")
 
         pipeline_instance = self._get_pipeline(model_id, effective_using)
-        timestamp = datetime.now() # Single timestamp for the batch run
-        parameters = { # Parameters for the whole batch
-            'categories': categories,
-            'model_id': model_id,
-            'using': effective_using,
-            'min_confidence': min_confidence,
-            'multi_label': multi_label,
-            'batch_size': batch_size,
-            **kwargs
+        timestamp = datetime.now()  # Single timestamp for the batch run
+        parameters = {  # Parameters for the whole batch
+            "labels": labels,
+            "model_id": model_id,
+            "using": effective_using,
+            "min_confidence": min_confidence,
+            "multi_label": multi_label,
+            "batch_size": batch_size,
+            **kwargs,
         }
 
-        logger.info(f"Classifying batch of {len(item_contents)} items with model '{model_id}' (batch size: {batch_size})")
+        logger.info(
+            f"Classifying batch of {len(item_contents)} items with model '{model_id}' (batch size: {batch_size})"
+        )
         batch_results_list: List[ClassificationResult] = []
 
         try:
             # Use pipeline directly for batching
             results_iterator = pipeline_instance(
                 item_contents,
-                candidate_labels=categories,
+                candidate_labels=labels,
                 multi_label=multi_label,
                 batch_size=batch_size,
-                **kwargs
+                **kwargs,
             )
-
+
             # Wrap with tqdm for progress if requested
             total_items = len(item_contents)
             if progress_bar:
-                # Get the appropriate tqdm class
-                tqdm_class = get_tqdm()
-                results_iterator = tqdm_class(
-                    results_iterator,
-                    total=total_items,
-                    desc=f"Classifying batch ({model_id})",
-                    leave=False # Don't leave progress bar hanging
-                )
+                # Get the appropriate tqdm class
+                tqdm_class = get_tqdm()
+                results_iterator = tqdm_class(
+                    results_iterator,
+                    total=total_items,
+                    desc=f"Classifying batch ({model_id})",
+                    leave=False,  # Don't leave progress bar hanging
+                )
 
             for raw_result in results_iterator:
-                # --- Process each raw result (which corresponds to ONE input item) --- #
+                # --- Process each raw result (which corresponds to ONE input item) --- #
                 scores_list: List[CategoryScore] = []
                 try:
                     # Check for text format (dict with 'labels' and 'scores')
-                    if isinstance(raw_result, dict) and 'labels' in raw_result and 'scores' in raw_result:
-                        for label, score_val in zip(raw_result['labels'], raw_result['scores']):
-                            if score_val >= min_confidence:
-                                try:
-                                    scores_list.append(CategoryScore(label=label, confidence=score_val))
-                                except (ValueError, TypeError) as score_err:
-                                    logger.warning(f"Skipping invalid score from text pipeline batch: label='{label}', score={score_val}. Error: {score_err}")
+                    if (
+                        isinstance(raw_result, dict)
+                        and "labels" in raw_result
+                        and "scores" in raw_result
+                    ):
+                        for label, score_val in zip(raw_result["labels"], raw_result["scores"]):
+                            if score_val >= min_confidence:
+                                try:
+                                    scores_list.append(CategoryScore(label, score_val))
+                                except (ValueError, TypeError) as score_err:
+                                    logger.warning(
+                                        f"Skipping invalid score from text pipeline batch: label='{label}', score={score_val}. Error: {score_err}"
+                                    )
                     # Check for vision format (list of dicts with 'label' and 'score')
                     elif isinstance(raw_result, list):
-                        for item in raw_result:
-                            try:
-                                score_val = item['score']
-                                label = item['label']
-                                if score_val >= min_confidence:
-                                    scores_list.append(CategoryScore(label=label, confidence=score_val))
-                            except (KeyError, ValueError, TypeError) as item_err:
-                                logger.warning(f"Skipping invalid item in vision result list from batch: {item}. Error: {item_err}")
+                        for item in raw_result:
+                            try:
+                                score_val = item["score"]
+                                label = item["label"]
+                                if score_val >= min_confidence:
+                                    scores_list.append(CategoryScore(label, score_val))
+                            except (KeyError, ValueError, TypeError) as item_err:
+                                logger.warning(
+                                    f"Skipping invalid item in vision result list from batch: {item}. Error: {item_err}"
+                                )
                     else:
-                        logger.warning(f"Unexpected raw result format in batch item from model '{model_id}': {type(raw_result)}. Cannot extract scores.")
-
+                        logger.warning(
+                            f"Unexpected raw result format in batch item from model '{model_id}': {type(raw_result)}. Cannot extract scores."
+                        )
+
                 except Exception as proc_err:
-                    logger.error(f"Error processing result item in batch: {proc_err}", exc_info=True)
-                    # scores_list remains empty for this item
+                    logger.error(
+                        f"Error processing result item in batch: {proc_err}", exc_info=True
+                    )
+                    # scores_list remains empty for this item
+
+                # --- Determine top category and score ---
+                scores_list.sort(key=lambda s: s.score, reverse=True)
+                top_category = scores_list[0].label
+                top_score = scores_list[0].score
+                # --- End Determine top category ---
 
                 # Append result object for this item
-                batch_results_list.append(ClassificationResult(
-                    model_id=model_id,
-                    using=effective_using,
-                    timestamp=timestamp, # Use same timestamp for batch
-                    parameters=parameters, # Use same params for batch
-                    scores=scores_list
-                ))
+                batch_results_list.append(
+                    ClassificationResult(
+                        scores=scores_list,  # Pass the full list, init will sort/filter
+                        model_id=model_id,
+                        using=effective_using,
+                        timestamp=timestamp,  # Use same timestamp for batch
+                        parameters=parameters,  # Use same params for batch
+                    )
+                )
                 # --- End Processing --- #
 
             if len(batch_results_list) != total_items:
-                logger.warning(f"Batch classification returned {len(batch_results_list)} results, but expected {total_items}. Results might be incomplete or misaligned.")
+                logger.warning(
+                    f"Batch classification returned {len(batch_results_list)} results, but expected {total_items}. Results might be incomplete or misaligned."
+                )
 
             return batch_results_list
 
         except Exception as e:
-            logger.error(f"Batch classification failed for model '{model_id}': {e}", exc_info=True)
-            # Return list of empty results?
-            # return [ClassificationResult(model_id=model_id, s=engine_type, timestamp=timestamp, parameters=parameters, scores=[]) for _ in item_contents]
-            raise ClassificationError(f"Batch classification failed using model '{model_id}'. Error: {e}") from e
+            logger.error(f"Batch classification failed for model '{model_id}': {e}", exc_info=True)
+            # Return list of empty results?
+            # return [ClassificationResult(model_id=model_id, s=engine_type, timestamp=timestamp, parameters=parameters, scores=[]) for _ in item_contents]
+            raise ClassificationError(
+                f"Batch classification failed using model '{model_id}'. Error: {e}"
+            ) from e
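The caller-visible change in this file is the rename of the `categories` parameter to `labels` in `classify_item` and `classify_batch`, with `CategoryScore` now constructed positionally as `CategoryScore(label, score)` and `ClassificationResult` deriving its top category from the `scores` it receives. A hedged usage sketch of the 0.1.10 call shape follows; the no-argument constructor, the sample text and labels, and the `result.scores` attribute access are assumptions, while `CategoryScore.label` and `.score` appear in the diff above.

```python
from natural_pdf.classification.manager import ClassificationManager

manager = ClassificationManager()  # assumption: a default device is accepted
result = manager.classify_item(
    "Quarterly revenue grew 12% year over year.",
    labels=["financial report", "legal filing", "news article"],  # was `categories` in 0.1.8
    min_confidence=0.1,
)

# Each retained entry is a CategoryScore(label, score) that met min_confidence.
for score in result.scores:
    print(score.label, score.score)
```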