natural-pdf 0.2.18__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- natural_pdf/__init__.py +8 -0
- natural_pdf/analyzers/checkbox/__init__.py +6 -0
- natural_pdf/analyzers/checkbox/base.py +265 -0
- natural_pdf/analyzers/checkbox/checkbox_analyzer.py +329 -0
- natural_pdf/analyzers/checkbox/checkbox_manager.py +166 -0
- natural_pdf/analyzers/checkbox/checkbox_options.py +60 -0
- natural_pdf/analyzers/checkbox/mixin.py +95 -0
- natural_pdf/analyzers/checkbox/rtdetr.py +201 -0
- natural_pdf/collections/mixins.py +14 -5
- natural_pdf/core/element_manager.py +5 -1
- natural_pdf/core/page.py +61 -0
- natural_pdf/core/page_collection.py +41 -1
- natural_pdf/core/pdf.py +24 -1
- natural_pdf/describe/base.py +20 -0
- natural_pdf/elements/base.py +152 -10
- natural_pdf/elements/element_collection.py +41 -2
- natural_pdf/elements/region.py +115 -2
- natural_pdf/judge.py +1509 -0
- natural_pdf/selectors/parser.py +42 -1
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/METADATA +1 -1
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/RECORD +41 -17
- temp/check_model.py +49 -0
- temp/check_pdf_content.py +9 -0
- temp/checkbox_checks.py +590 -0
- temp/checkbox_simple.py +117 -0
- temp/checkbox_ux_ideas.py +400 -0
- temp/context_manager_prototype.py +177 -0
- temp/convert_to_hf.py +60 -0
- temp/demo_text_closest.py +66 -0
- temp/inspect_model.py +43 -0
- temp/rtdetr_dinov2_test.py +49 -0
- temp/test_closest_debug.py +26 -0
- temp/test_closest_debug2.py +22 -0
- temp/test_context_exploration.py +85 -0
- temp/test_durham.py +30 -0
- temp/test_empty_string.py +16 -0
- temp/test_similarity.py +15 -0
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/WHEEL +0 -0
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/entry_points.txt +0 -0
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,166 @@
|
|
1
|
+
"""Manager for checkbox detection engines."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import Any, Dict, List, Optional, Type, Union
|
5
|
+
|
6
|
+
from PIL import Image
|
7
|
+
|
8
|
+
from .base import CheckboxDetector
|
9
|
+
from .checkbox_options import CheckboxOptions, RTDETRCheckboxOptions
|
10
|
+
|
11
|
+
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
|
14
|
+
def _lazy_import_rtdetr_detector():
    """Return the RT-DETR detector class, importing its module only when called.

    Deferring the import keeps the detector module (and whatever it pulls in)
    out of this module's import-time cost.
    """
    from .rtdetr import RTDETRCheckboxDetector as detector_cls

    return detector_cls
|
19
|
+
|
20
|
+
|
21
|
+
class CheckboxManager:
    """Manages checkbox detection engines and provides a unified interface.

    Detector instances are created on first use and cached per engine name on
    the manager instance, so repeated calls with the same engine reuse the
    same detector object.
    """

    # Registry of available engines.
    # "class" is either a detector class or a zero-argument factory returning
    # one (the factory form keeps the detector module import lazy);
    # "options_class" is the options dataclass instantiated for that engine.
    ENGINE_REGISTRY = {
        "rtdetr": {
            "class": _lazy_import_rtdetr_detector,
            "options_class": RTDETRCheckboxOptions,
        },
        "wendys": {  # Alias for the default model
            "class": _lazy_import_rtdetr_detector,
            "options_class": RTDETRCheckboxOptions,
        },
    }

    def __init__(self):
        """Initialize the checkbox manager with an empty detector cache."""
        self.logger = logging.getLogger(__name__)
        self._detector_cache: Dict[str, CheckboxDetector] = {}

    def detect_checkboxes(
        self,
        image: Image.Image,
        engine: Optional[str] = None,
        options: Optional[Union[CheckboxOptions, Dict[str, Any]]] = None,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """
        Detect checkboxes in an image using the specified engine.

        Args:
            image: PIL Image to analyze
            engine: Name of the detection engine (default: 'rtdetr')
            options: CheckboxOptions instance or dict of options
            **kwargs: Additional options to override; these win over values
                supplied through ``options``

        Returns:
            List of detection dictionaries

        Raises:
            ValueError: If ``engine`` is not in ``ENGINE_REGISTRY``.
            RuntimeError: If the engine's dependencies are not available.
            Exception: Any error raised by the detector is logged and re-raised.
        """
        if options is None or isinstance(options, dict):
            if engine is None:
                engine = "rtdetr"  # Default engine
            # Merge first so a key present in both the options dict and kwargs
            # is overridden by kwargs; expanding both with `**` in one call
            # would raise "got multiple values for keyword argument".
            merged = {**(options or {}), **kwargs}
            options = self._create_options(engine, **merged)
        else:
            # options is a CheckboxOptions instance: infer the engine from its
            # type unless the caller named one explicitly.
            if engine is None:
                engine = self._get_engine_from_options(options)
            if kwargs:
                options = self._override_options(options, **kwargs)

        detector = self._get_detector(engine)

        try:
            return detector.detect(image, options)
        except Exception as e:
            # Log with traceback here, then let the caller decide how to handle it.
            self.logger.exception("Checkbox detection failed with %s: %s", engine, e)
            raise

    def _ensure_known_engine(self, engine: str) -> None:
        """Raise ValueError if ``engine`` is not a key of ``ENGINE_REGISTRY``."""
        if engine not in self.ENGINE_REGISTRY:
            raise ValueError(
                f"Unknown checkbox detection engine: {engine}. "
                f"Available: {list(self.ENGINE_REGISTRY.keys())}"
            )

    def _resolve_detector_class(self, engine: str) -> Type[CheckboxDetector]:
        """Return the detector class for a registered engine, calling its factory if needed."""
        entry = self.ENGINE_REGISTRY[engine]["class"]
        # A class is itself callable, so `callable()` alone cannot distinguish a
        # detector class from a lazy-import factory; only invoke non-class entries.
        if not isinstance(entry, type) and callable(entry):
            entry = entry()
        return entry

    def _get_engine_from_options(self, options: CheckboxOptions) -> str:
        """Return the first registered engine whose options class matches ``options``.

        Falls back to 'rtdetr' when no registered options class matches.
        """
        for engine_name, engine_info in self.ENGINE_REGISTRY.items():
            if isinstance(options, engine_info["options_class"]):
                return engine_name
        # Default if can't determine
        return "rtdetr"

    def _create_options(self, engine: str, **kwargs) -> CheckboxOptions:
        """Create an options instance for ``engine`` from keyword arguments.

        Raises:
            ValueError: If ``engine`` is not registered.
        """
        self._ensure_known_engine(engine)
        options_class = self.ENGINE_REGISTRY[engine]["options_class"]
        return options_class(**kwargs)

    def _override_options(self, options: CheckboxOptions, **kwargs) -> CheckboxOptions:
        """Create a new options instance of the same type with overrides applied.

        The input instance is left untouched.
        """
        import dataclasses

        # asdict() copies the current field values into a plain dict, so the
        # update below cannot affect the original instance.
        current_values = dataclasses.asdict(options)
        current_values.update(kwargs)
        return type(options)(**current_values)

    def _get_detector(self, engine: str) -> CheckboxDetector:
        """Get or create the cached detector instance for ``engine``.

        Raises:
            ValueError: If ``engine`` is not registered.
            RuntimeError: If the detector reports it is not available.
        """
        if engine not in self._detector_cache:
            self._ensure_known_engine(engine)

            detector_class = self._resolve_detector_class(engine)

            if not detector_class.is_available():
                raise RuntimeError(
                    f"Checkbox detection engine '{engine}' is not available. "
                    f"Please install required dependencies."
                )

            self._detector_cache[engine] = detector_class()
            self.logger.info(f"Initialized checkbox detector: {engine}")

        return self._detector_cache[engine]

    def is_engine_available(self, engine: str) -> bool:
        """Check if a specific engine is registered and reports itself available.

        Best-effort: any error while resolving or probing the detector is
        treated as "not available" rather than propagated.
        """
        if engine not in self.ENGINE_REGISTRY:
            return False

        try:
            return self._resolve_detector_class(engine).is_available()
        except Exception:
            # Keep the boolean contract, but leave a trace for debugging.
            self.logger.debug(
                "Availability check failed for checkbox engine %s", engine, exc_info=True
            )
            return False

    def list_available_engines(self) -> List[str]:
        """List all registered engine names whose detector is available."""
        return [name for name in self.ENGINE_REGISTRY if self.is_engine_available(name)]
|
@@ -0,0 +1,60 @@
|
|
1
|
+
"""Options classes for checkbox detection engines."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from dataclasses import dataclass, field
|
5
|
+
from typing import Any, Dict, Optional
|
6
|
+
|
7
|
+
logger = logging.getLogger(__name__)
|
8
|
+
|
9
|
+
|
10
|
+
def _default_label_mapping() -> Dict[str, str]:
    """Build a fresh raw-model-label -> standard-state table for each instance."""
    checked, unchecked = "checked", "unchecked"
    return {
        # Common label spellings
        "checkbox": unchecked,
        "checked_checkbox": checked,
        "checkbox_checked": checked,
        "unchecked_checkbox": unchecked,
        "checkbox_unchecked": unchecked,
        # Numeric labels
        "0": unchecked,
        "1": checked,
        # Descriptive labels
        "empty": unchecked,
        "tick": checked,
        "filled": checked,
        "blank": unchecked,
    }


@dataclass
class CheckboxOptions:
    """Settings shared by every checkbox detection engine."""

    # Minimum score to keep a detection; deliberately very low for DETR models.
    confidence: float = 0.02
    # DPI used when rendering pages to images.
    resolution: int = 150
    # Preferred inference device ('cpu', 'cuda', 'mps', etc.).
    device: Optional[str] = "cpu"

    # Translation from model output labels to the standard states.
    label_mapping: Dict[str, str] = field(default_factory=_default_label_mapping)

    # Non-max suppression: IoU threshold for overlapping boxes (kept low for checkboxes).
    nms_threshold: float = 0.1

    # Text filtering: reject detections that contain text (checkboxes should be empty).
    reject_with_text: bool = True

    # Free-form bag for engine-specific parameters.
    extra_args: Dict[str, Any] = field(default_factory=dict)
|
48
|
+
|
49
|
+
|
50
|
+
@dataclass
class RTDETRCheckboxOptions(CheckboxOptions):
    """CheckboxOptions extended with the settings used by RT-DETR models."""

    # Model repository to load; this is the default checkbox model.
    model_repo: str = "wendys-llc/rtdetr-v2-r50-chkbx"
    # Specific model revision, or None for the repository default.
    model_revision: Optional[str] = None
    # Alternative repository for the image processor, if an override is needed.
    image_processor_repo: Optional[str] = None

    # RT-DETR specific parameters
    # Maximum number of detections kept per image.
    max_detections: int = 100
    # Score threshold handed to post-processing (0.0 keeps everything).
    post_process_threshold: float = 0.0
|
@@ -0,0 +1,95 @@
|
|
1
|
+
"""Checkbox detection mixin for Page and Region classes."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from natural_pdf.analyzers.checkbox.checkbox_options import CheckboxOptions
|
8
|
+
from natural_pdf.elements.element_collection import ElementCollection
|
9
|
+
from natural_pdf.elements.region import Region
|
10
|
+
|
11
|
+
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
|
14
|
+
class CheckboxDetectionMixin:
    """Mixin to add checkbox detection capabilities to Page and Region classes."""

    def detect_checkboxes(
        self,
        engine: Optional[str] = None,
        options: Optional[Union["CheckboxOptions", Dict[str, Any]]] = None,
        confidence: Optional[float] = None,
        resolution: Optional[int] = None,
        device: Optional[str] = None,
        existing: str = "replace",
        limit: Optional[int] = None,
        **kwargs,
    ) -> "ElementCollection[Region]":
        """
        Detect checkboxes in the page or region.

        This method identifies checkboxes and their states (checked/unchecked) using
        computer vision models. Detected checkboxes are added as Region objects with
        type="checkbox" and can be accessed via selectors like page.find_all('checkbox').

        All arguments are forwarded unchanged to CheckboxAnalyzer.detect_checkboxes;
        the regions it returns are wrapped in an ElementCollection.

        Args:
            engine: Name of the detection engine (default: 'rtdetr' for wendys model)
            options: CheckboxOptions instance or dict of options for advanced configuration
            confidence: Minimum confidence threshold (default: 0.02 for DETR models)
            resolution: DPI for rendering pages to images (default: 150)
            device: Device for inference ('cpu', 'cuda', 'mps', etc.)
            existing: How to handle existing checkbox regions: 'replace' (default) or 'append'
            limit: Maximum number of checkboxes to detect (useful when you know the expected count)
            **kwargs: Additional engine-specific arguments

        Returns:
            ElementCollection containing detected checkbox Region objects with attributes:
            - region_type: "checkbox"
            - is_checked: bool indicating if checkbox is checked
            - checkbox_state: "checked" or "unchecked"
            - confidence: detection confidence score

        Examples:
            # Basic detection
            checkboxes = page.detect_checkboxes()

            # Find checked boxes
            checked = page.find_all('checkbox:checked')
            unchecked = page.find_all('checkbox:unchecked')

            # Limit to expected number
            checkboxes = page.detect_checkboxes(limit=10)

            # High confidence detection
            checkboxes = page.detect_checkboxes(confidence=0.9)

            # GPU acceleration
            checkboxes = page.detect_checkboxes(device='cuda')

            # Custom model (model_repo is a field of RTDETRCheckboxOptions,
            # not of the base CheckboxOptions)
            from natural_pdf.analyzers.checkbox.checkbox_options import RTDETRCheckboxOptions
            options = RTDETRCheckboxOptions(model_repo="your-org/your-checkbox-model")
            checkboxes = page.detect_checkboxes(options=options)
        """
        # Lazy import to avoid circular dependencies
        from natural_pdf.analyzers.checkbox.checkbox_analyzer import CheckboxAnalyzer

        # The analyzer is bound to this page/region instance.
        analyzer = CheckboxAnalyzer(self)

        regions = analyzer.detect_checkboxes(
            engine=engine,
            options=options,
            confidence=confidence,
            resolution=resolution,
            device=device,
            existing=existing,
            limit=limit,
            **kwargs,
        )

        # Imported here for the same circular-import reason as above.
        from natural_pdf.elements.element_collection import ElementCollection

        return ElementCollection(regions)
|
@@ -0,0 +1,201 @@
|
|
1
|
+
"""RT-DETR based checkbox detector implementation."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple
|
5
|
+
|
6
|
+
from PIL import Image
|
7
|
+
|
8
|
+
from .base import CheckboxDetector
|
9
|
+
from .checkbox_options import RTDETRCheckboxOptions
|
10
|
+
|
11
|
+
logger = logging.getLogger(__name__)
|
12
|
+
|
13
|
+
# Lazy imports cache: holds (AutoImageProcessor, AutoModelForObjectDetection)
# after the first successful import, None until then.
_transformers_cache = None


def _get_transformers():
    """Lazy import transformers to avoid heavy dependency at module load.

    Returns:
        The tuple ``(AutoImageProcessor, AutoModelForObjectDetection)``; the
        result of the first successful import is cached and reused.

    Raises:
        ImportError: If transformers cannot be imported. The original import
            error is attached as ``__cause__``.
    """
    global _transformers_cache
    if _transformers_cache is None:
        try:
            from transformers import AutoImageProcessor, AutoModelForObjectDetection
        except ImportError as err:
            # Chain explicitly so the underlying failure (e.g. a broken
            # transitive dependency) stays visible in the traceback.
            raise ImportError(
                "transformers library is required for RT-DETR checkbox detection. "
                "Install it with: pip install transformers"
            ) from err

        _transformers_cache = (AutoImageProcessor, AutoModelForObjectDetection)
    return _transformers_cache
|
31
|
+
|
32
|
+
|
33
|
+
def _get_torch():
    """Lazy import torch.

    Returns:
        The imported ``torch`` module.

    Raises:
        ImportError: If torch cannot be imported. The original import error
            is attached as ``__cause__``.
    """
    try:
        import torch
    except ImportError as err:
        # Chain explicitly so the underlying failure stays visible.
        raise ImportError(
            "torch is required for RT-DETR checkbox detection. "
            "Install it with: pip install torch"
        ) from err
    return torch
|
44
|
+
|
45
|
+
|
46
|
+
class RTDETRCheckboxDetector(CheckboxDetector):
    """RT-DETR based checkbox detector using HuggingFace transformers."""

    def __init__(self):
        """Initialize the RT-DETR checkbox detector."""
        super().__init__()

    @classmethod
    def is_available(cls) -> bool:
        """Check if transformers and torch are available."""
        try:
            _get_transformers()
            _get_torch()
        except ImportError:
            return False
        return True

    def _get_cache_key(self, options: RTDETRCheckboxOptions) -> str:
        """Generate cache key including model repo, revision and processor override."""
        base_key = super()._get_cache_key(options)
        model_key = options.model_repo.replace("/", "_")
        revision_key = options.model_revision or "default"
        key = f"{base_key}_{model_key}_{revision_key}"
        # The bundle returned by _load_model_from_options also contains the
        # image processor, so a processor override must produce its own key.
        # Keys are unchanged when no override is set.
        if options.image_processor_repo:
            processor_key = options.image_processor_repo.replace("/", "_")
            key = f"{key}_{processor_key}"
        return key

    def _resolve_device(self, torch, requested: Optional[str]) -> str:
        """Return the device string to move the model to, falling back to 'cpu'.

        Accepts plain names ('cuda', 'mps') as well as indexed forms such as
        'cuda:0'. Logs a warning when the requested device cannot be used.
        """
        if not requested or requested == "cpu":
            return "cpu"

        try:
            target = torch.device(requested)
        except (RuntimeError, TypeError, ValueError):
            target = None  # Unparseable device string

        if target is not None:
            if target.type == "cpu":
                return requested
            if target.type == "cuda":
                if torch.cuda.is_available() and (
                    target.index is None or target.index < torch.cuda.device_count()
                ):
                    return requested
            elif target.type == "mps":
                # torch.backends.mps is absent on older torch builds.
                mps_backend = getattr(torch.backends, "mps", None)
                if mps_backend is not None and mps_backend.is_available():
                    return requested

        self.logger.warning(f"Requested device '{requested}' not available, using CPU")
        return "cpu"

    def _load_model_from_options(self, options: RTDETRCheckboxOptions) -> Dict[str, Any]:
        """Load RT-DETR model and processor from HuggingFace.

        Returns:
            Dict with keys "model" (in eval mode, on the resolved device),
            "processor" and "device" (the device of the model's parameters).

        Raises:
            ImportError: If transformers or torch is missing.
            RuntimeError: If loading fails for any reason; the original error
                is attached as ``__cause__``.
        """
        AutoImageProcessor, AutoModelForObjectDetection = _get_transformers()
        torch = _get_torch()

        # The processor comes from the override repo when set, else from the model repo.
        processor_repo = options.image_processor_repo or options.model_repo

        try:
            # NOTE(review): model_revision is also applied to an overriding
            # processor repo, as in the original code; confirm this is intended.
            image_processor = AutoImageProcessor.from_pretrained(
                processor_repo, revision=options.model_revision
            )

            model = AutoModelForObjectDetection.from_pretrained(
                options.model_repo, revision=options.model_revision
            )

            model = model.to(self._resolve_device(torch, options.device))

            # Inference only: switch off training-mode behaviour.
            model.eval()

            return {
                "model": model,
                "processor": image_processor,
                "device": next(model.parameters()).device,
            }

        except Exception as e:
            raise RuntimeError(
                f"Failed to load checkbox model '{options.model_repo}'. "
                f"This may be due to network issues or missing credentials. "
                f"Original error: {e}"
            ) from e

    def detect(self, image: Image.Image, options: RTDETRCheckboxOptions) -> List[Dict[str, Any]]:
        """
        Detect checkboxes in the given image using RT-DETR.

        Args:
            image: PIL Image to analyze
            options: RT-DETR specific options

        Returns:
            List of standardized detection dictionaries with keys "bbox",
            "class", "normalized_class", "is_checked", "checkbox_state",
            "confidence", "model" and "source".
        """
        torch = _get_torch()

        # Get cached model
        model_dict = self._get_model(options)
        model = model_dict["model"]
        processor = model_dict["processor"]
        device = model_dict["device"]

        # Prepare inputs on the same device as the model
        inputs = processor(images=[image], return_tensors="pt")
        if device.type != "cpu":
            inputs = {k: v.to(device) for k, v in inputs.items()}

        # Run inference without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)

        # PIL reports (width, height); post-processing expects (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        if device.type != "cpu":
            target_sizes = target_sizes.to(device)

        results = processor.post_process_object_detection(
            outputs, threshold=options.post_process_threshold, target_sizes=target_sizes
        )[0]

        # Hoisted out of the loop: label id -> label text table, if the model has one.
        id2label = getattr(model.config, "id2label", None)
        model_name = options.model_repo.split("/")[-1]  # Short model name

        # Convert to standardized format
        detections = []
        for score_t, label_t, box_t in zip(
            results["scores"], results["labels"], results["boxes"]
        ):
            score = score_t.item()

            # Apply confidence threshold
            if score < options.confidence:
                continue

            label = label_t.item()

            # Get label text from model config, else fall back to the numeric id
            if id2label is not None and label in id2label:
                label_text = id2label[label]
            else:
                label_text = str(label)

            # Map to checkbox state
            is_checked, state = self._map_label_to_state(label_text, options)

            detections.append(
                {
                    "bbox": tuple(box_t.tolist()),  # (x0, y0, x1, y1)
                    "class": label_text,
                    "normalized_class": "checkbox",
                    "is_checked": is_checked,
                    "checkbox_state": state,
                    "confidence": score,
                    "model": model_name,
                    "source": "checkbox",
                }
            )

        # Apply NMS if needed
        if options.nms_threshold > 0:
            detections = self._apply_nms(detections, options.nms_threshold)

        # Limit detections if specified: keep the top N by confidence
        if options.max_detections > 0 and len(detections) > options.max_detections:
            detections = sorted(detections, key=lambda x: x["confidence"], reverse=True)
            detections = detections[: options.max_detections]

        return detections
|
@@ -99,10 +99,6 @@ class ApplyMixin:
|
|
99
99
|
|
100
100
|
results = [func(item, *args, **kwargs) for item in items_iterable]
|
101
101
|
|
102
|
-
# If results is empty, return an empty list
|
103
|
-
if not results:
|
104
|
-
return []
|
105
|
-
|
106
102
|
# Import here to avoid circular imports
|
107
103
|
from natural_pdf import PDF, Page
|
108
104
|
from natural_pdf.core.page_collection import PageCollection
|
@@ -111,11 +107,24 @@ class ApplyMixin:
|
|
111
107
|
from natural_pdf.elements.element_collection import ElementCollection
|
112
108
|
from natural_pdf.elements.region import Region
|
113
109
|
|
110
|
+
# Determine the return type based on the input collection type
|
111
|
+
# This handles empty results correctly
|
112
|
+
if self.__class__.__name__ == "ElementCollection":
|
113
|
+
return ElementCollection(results)
|
114
|
+
elif self.__class__.__name__ == "PageCollection":
|
115
|
+
return PageCollection(results)
|
116
|
+
elif self.__class__.__name__ == "PDFCollection":
|
117
|
+
return PDFCollection(results)
|
118
|
+
|
119
|
+
# If not a known collection type, try to infer from results
|
120
|
+
if not results:
|
121
|
+
return []
|
122
|
+
|
114
123
|
first_non_none = next((r for r in results if r is not None), None)
|
115
124
|
first_type = type(first_non_none) if first_non_none is not None else None
|
116
125
|
|
117
126
|
# Return the appropriate collection based on result type (...generally)
|
118
|
-
if issubclass(first_type, Element) or issubclass(first_type, Region):
|
127
|
+
if first_type and (issubclass(first_type, Element) or issubclass(first_type, Region)):
|
119
128
|
return ElementCollection(results)
|
120
129
|
elif first_type == PDF:
|
121
130
|
return PDFCollection(results)
|
@@ -584,13 +584,17 @@ class ElementManager:
|
|
584
584
|
|
585
585
|
# Add regions if they exist
|
586
586
|
if hasattr(self._page, "_regions") and (
|
587
|
-
"detected" in self._page._regions
|
587
|
+
"detected" in self._page._regions
|
588
|
+
or "named" in self._page._regions
|
589
|
+
or "checkbox" in self._page._regions
|
588
590
|
):
|
589
591
|
regions = []
|
590
592
|
if "detected" in self._page._regions:
|
591
593
|
regions.extend(self._page._regions["detected"])
|
592
594
|
if "named" in self._page._regions:
|
593
595
|
regions.extend(self._page._regions["named"].values())
|
596
|
+
if "checkbox" in self._page._regions:
|
597
|
+
regions.extend(self._page._regions["checkbox"])
|
594
598
|
self._elements["regions"] = regions
|
595
599
|
logger.debug(f"Page {self._page.number}: Added {len(regions)} regions.")
|
596
600
|
else:
|
natural_pdf/core/page.py
CHANGED
@@ -50,6 +50,7 @@ import numpy as np
|
|
50
50
|
from pdfplumber.utils.geometry import get_bbox_overlap, merge_bboxes, objects_to_bbox
|
51
51
|
from pdfplumber.utils.text import TEXTMAP_KWARGS, WORD_EXTRACTOR_KWARGS, chars_to_textmap
|
52
52
|
|
53
|
+
from natural_pdf.analyzers.checkbox.mixin import CheckboxDetectionMixin
|
53
54
|
from natural_pdf.analyzers.layout.layout_analyzer import LayoutAnalyzer
|
54
55
|
from natural_pdf.analyzers.layout.layout_manager import LayoutManager
|
55
56
|
from natural_pdf.analyzers.layout.layout_options import LayoutOptions
|
@@ -103,6 +104,7 @@ class Page(
|
|
103
104
|
ClassificationMixin,
|
104
105
|
ExtractionMixin,
|
105
106
|
ShapeDetectionMixin,
|
107
|
+
CheckboxDetectionMixin,
|
106
108
|
DescribeMixin,
|
107
109
|
VisualSearchMixin,
|
108
110
|
Visualizable,
|
@@ -1491,6 +1493,65 @@ class Page(
|
|
1491
1493
|
"Cannot sort elements in reading order: Missing required attributes (top, x0)."
|
1492
1494
|
)
|
1493
1495
|
|
1496
|
+
# Handle :closest pseudo-class for fuzzy text matching
|
1497
|
+
for pseudo in selector_obj.get("pseudo_classes", []):
|
1498
|
+
name = pseudo.get("name")
|
1499
|
+
if name == "closest" and pseudo.get("args") is not None:
|
1500
|
+
import difflib
|
1501
|
+
|
1502
|
+
# Parse search text and threshold
|
1503
|
+
search_text = str(pseudo["args"]).strip()
|
1504
|
+
threshold = 0.0 # Default threshold
|
1505
|
+
|
1506
|
+
# Handle empty search text
|
1507
|
+
if not search_text:
|
1508
|
+
matching_elements = []
|
1509
|
+
break
|
1510
|
+
|
1511
|
+
# Check if threshold is specified with @ separator
|
1512
|
+
if "@" in search_text and search_text.count("@") == 1:
|
1513
|
+
text_part, threshold_part = search_text.rsplit("@", 1)
|
1514
|
+
try:
|
1515
|
+
threshold = float(threshold_part)
|
1516
|
+
search_text = text_part.strip()
|
1517
|
+
except (ValueError, TypeError):
|
1518
|
+
pass # Keep original search_text and default threshold
|
1519
|
+
|
1520
|
+
# Determine case sensitivity
|
1521
|
+
ignore_case = not kwargs.get("case", False)
|
1522
|
+
|
1523
|
+
# Calculate similarity scores for all elements
|
1524
|
+
scored_elements = []
|
1525
|
+
|
1526
|
+
for el in matching_elements:
|
1527
|
+
if hasattr(el, "text") and el.text:
|
1528
|
+
el_text = el.text.strip()
|
1529
|
+
search_term = search_text
|
1530
|
+
|
1531
|
+
if ignore_case:
|
1532
|
+
el_text = el_text.lower()
|
1533
|
+
search_term = search_term.lower()
|
1534
|
+
|
1535
|
+
# Calculate similarity ratio
|
1536
|
+
ratio = difflib.SequenceMatcher(None, search_term, el_text).ratio()
|
1537
|
+
|
1538
|
+
# Check if element contains the search term as substring
|
1539
|
+
contains_match = search_term in el_text
|
1540
|
+
|
1541
|
+
# Store element with its similarity score and contains flag
|
1542
|
+
if ratio >= threshold:
|
1543
|
+
scored_elements.append((ratio, contains_match, el))
|
1544
|
+
|
1545
|
+
# Sort by:
|
1546
|
+
# 1. Contains match (True before False)
|
1547
|
+
# 2. Similarity score (highest first)
|
1548
|
+
# This ensures substring matches come first but are sorted by similarity
|
1549
|
+
scored_elements.sort(key=lambda x: (x[1], x[0]), reverse=True)
|
1550
|
+
|
1551
|
+
# Extract just the elements
|
1552
|
+
matching_elements = [el for _, _, el in scored_elements]
|
1553
|
+
break # Only process the first :closest pseudo-class
|
1554
|
+
|
1494
1555
|
# Handle collection-level pseudo-classes (:first, :last)
|
1495
1556
|
for pseudo in selector_obj.get("pseudo_classes", []):
|
1496
1557
|
name = pseudo.get("name")
|