natural-pdf 0.2.18__py3-none-any.whl → 0.2.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- natural_pdf/__init__.py +8 -0
- natural_pdf/analyzers/checkbox/__init__.py +6 -0
- natural_pdf/analyzers/checkbox/base.py +265 -0
- natural_pdf/analyzers/checkbox/checkbox_analyzer.py +329 -0
- natural_pdf/analyzers/checkbox/checkbox_manager.py +166 -0
- natural_pdf/analyzers/checkbox/checkbox_options.py +60 -0
- natural_pdf/analyzers/checkbox/mixin.py +95 -0
- natural_pdf/analyzers/checkbox/rtdetr.py +201 -0
- natural_pdf/collections/mixins.py +14 -5
- natural_pdf/core/element_manager.py +5 -1
- natural_pdf/core/page.py +61 -0
- natural_pdf/core/page_collection.py +41 -1
- natural_pdf/core/pdf.py +24 -1
- natural_pdf/describe/base.py +20 -0
- natural_pdf/elements/base.py +152 -10
- natural_pdf/elements/element_collection.py +41 -2
- natural_pdf/elements/region.py +115 -2
- natural_pdf/judge.py +1509 -0
- natural_pdf/selectors/parser.py +42 -1
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/METADATA +1 -1
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/RECORD +41 -17
- temp/check_model.py +49 -0
- temp/check_pdf_content.py +9 -0
- temp/checkbox_checks.py +590 -0
- temp/checkbox_simple.py +117 -0
- temp/checkbox_ux_ideas.py +400 -0
- temp/context_manager_prototype.py +177 -0
- temp/convert_to_hf.py +60 -0
- temp/demo_text_closest.py +66 -0
- temp/inspect_model.py +43 -0
- temp/rtdetr_dinov2_test.py +49 -0
- temp/test_closest_debug.py +26 -0
- temp/test_closest_debug2.py +22 -0
- temp/test_context_exploration.py +85 -0
- temp/test_durham.py +30 -0
- temp/test_empty_string.py +16 -0
- temp/test_similarity.py +15 -0
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/WHEEL +0 -0
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/entry_points.txt +0 -0
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/licenses/LICENSE +0 -0
- {natural_pdf-0.2.18.dist-info → natural_pdf-0.2.19.dist-info}/top_level.txt +0 -0
natural_pdf/judge.py
ADDED
@@ -0,0 +1,1509 @@
|
|
1
|
+
"""
|
2
|
+
Visual Judge for classifying regions based on image content.
|
3
|
+
|
4
|
+
This module provides a simple visual classifier that learns from examples
|
5
|
+
to classify regions (like checkboxes) into categories. It uses basic image
|
6
|
+
metrics rather than neural networks for fast, interpretable results.
|
7
|
+
"""
|
8
|
+
|
9
|
+
import hashlib
|
10
|
+
import json
|
11
|
+
import logging
|
12
|
+
import shutil
|
13
|
+
from collections import namedtuple
|
14
|
+
from pathlib import Path
|
15
|
+
from typing import Dict, List, Optional, Tuple, Union
|
16
|
+
|
17
|
+
import numpy as np
|
18
|
+
from PIL import Image
|
19
|
+
|
20
|
+
logger = logging.getLogger(__name__)

# Public result records: Judge.decide() yields Decision(label, score) and
# Judge.pick() yields PickResult(region, index, label, score).
Decision = namedtuple("Decision", "label score")
PickResult = namedtuple("PickResult", "region index label score")
|
25
|
+
|
26
|
+
|
27
|
+
class JudgeError(Exception):
    """Error raised when a Judge operation cannot be completed (invalid
    labels, missing training examples, regions that fail to render, ...)."""
|
31
|
+
|
32
|
+
|
33
|
+
class Judge:
|
34
|
+
"""
|
35
|
+
Visual classifier for regions using simple image metrics.
|
36
|
+
|
37
|
+
Requires class labels to be specified. For binary classification,
|
38
|
+
requires at least one example of each class before making decisions.
|
39
|
+
|
40
|
+
Examples:
|
41
|
+
Checkbox detection:
|
42
|
+
```python
|
43
|
+
judge = Judge("checkboxes", labels=["unchecked", "checked"])
|
44
|
+
judge.add(empty_box, "unchecked")
|
45
|
+
judge.add(marked_box, "checked")
|
46
|
+
|
47
|
+
result = judge.decide(new_box)
|
48
|
+
if result.label == "checked":
|
49
|
+
print("Box is checked!")
|
50
|
+
```
|
51
|
+
|
52
|
+
Signature detection:
|
53
|
+
```python
|
54
|
+
judge = Judge("signatures", labels=["unsigned", "signed"])
|
55
|
+
judge.add(blank_area, "unsigned")
|
56
|
+
judge.add(signature_area, "signed")
|
57
|
+
|
58
|
+
result = judge.decide(new_region)
|
59
|
+
print(f"Classification: {result.label} (confidence: {result.score})")
|
60
|
+
```
|
61
|
+
"""
|
62
|
+
|
63
|
+
def __init__(
    self,
    name: str,
    labels: List[str],
    base_dir: Optional[str] = None,
    target_prior: Optional[float] = None,
):
    """
    Initialize a Judge for visual classification.

    Creates (if missing) the folder ``<base_dir>/<name>/`` with one
    sub-folder per label plus ``unlabeled`` and ``_removed``, then loads
    ``judge.json`` from that folder if it already exists.

    Args:
        name: Name for this judge (used for folder name)
        labels: Exactly two distinct class labels. Any sequence is accepted
            and stored as a list. "unlabeled" and "_removed" are reserved
            folder names and cannot be used as labels.
        base_dir: Base directory for storage. Defaults to the current working
            directory. Missing parent directories are created.
        target_prior: Target prior probability for the FIRST label in the
            labels list, between 0 and 1.
            - None / 0.5 (default) = neutral, treats both classes equally
            - >0.5 = favors labels[0]
            - <0.5 = favors labels[1]
            Example: Judge("cb", ["checked", "unchecked"], target_prior=0.6)
            favors detecting "checked" checkboxes.

    Raises:
        JudgeError: If labels are not exactly two distinct, non-reserved
            names, or if target_prior is outside [0, 1].
    """
    if not labels or len(labels) != 2:
        raise JudgeError("Judge requires exactly 2 class labels (binary classification only)")

    # Store a private list copy: later mutation of the caller's sequence cannot
    # change this judge, and list operations work even if a tuple was passed.
    labels = list(labels)
    if labels[0] == labels[1]:
        raise JudgeError(f"Class labels must be distinct, got {labels}")
    reserved = [label for label in labels if label in ("unlabeled", "_removed")]
    if reserved:
        raise JudgeError(
            f"Labels {reserved} are reserved folder names; choose different labels"
        )
    if target_prior is not None and not 0.0 <= target_prior <= 1.0:
        raise JudgeError(f"target_prior must be between 0 and 1, got {target_prior}")

    self.name = name
    self.labels = labels
    self.target_prior = target_prior if target_prior is not None else 0.5

    # Set up directory structure; parents=True so a fresh base_dir works.
    self.base_dir = Path(base_dir) if base_dir else Path.cwd()
    self.root_dir = self.base_dir / name
    self.root_dir.mkdir(parents=True, exist_ok=True)
    for folder in [*self.labels, "unlabeled", "_removed"]:
        (self.root_dir / folder).mkdir(exist_ok=True)

    # Config file
    self.config_path = self.root_dir / "judge.json"

    # Load existing config or start untrained.
    self.thresholds = {}
    self.metrics_info = {}
    if self.config_path.exists():
        self._load_config()
|
110
|
+
|
111
|
+
def add(self, region, label: Optional[str] = None) -> None:
    """
    Add a region to the judge's dataset.

    The region is rendered with ``crop=True``, converted to RGB and saved as
    ``<hash>.png`` in the label's folder (or ``unlabeled``). ``<hash>`` is the
    first 12 hex digits of the MD5 of the pixel bytes. If an image with the
    same hash already exists in any folder (including ``_removed``), nothing
    is saved and a warning is logged and printed.

    Saving a labeled example discards the cached thresholds, so the next
    ``decide()`` call retrains on the updated dataset.

    Args:
        region: Object with a ``render(crop=True)`` method that returns a PIL
            image or an array accepted by ``Image.fromarray``.
        label: Class label. If None, added to unlabeled for later teaching

    Raises:
        JudgeError: If label is not in allowed labels, or the region cannot
            be rendered.
    """
    if label is not None and label not in self.labels:
        raise JudgeError(f"Label '{label}' not in allowed labels: {self.labels}")

    # Render region to image
    try:
        img = region.render(crop=True)
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
    except Exception as e:
        raise JudgeError(f"Failed to render region: {e}") from e

    if img.mode != "RGB":
        img = img.convert("RGB")

    # Content-addressed file name: identical pixels always map to the same
    # name, which is what makes the duplicate check below possible.
    img_hash = hashlib.md5(np.array(img).tobytes()).hexdigest()[:12]
    filename = f"{img_hash}.png"
    folder = label if label else "unlabeled"

    existing_locations = [
        location
        for location in [*self.labels, "unlabeled", "_removed"]
        if (self.root_dir / location / filename).exists()
    ]
    if existing_locations:
        where = ", ".join(existing_locations)
        logger.warning(f"Duplicate image detected (hash: {img_hash})")
        logger.warning(f"Already exists in: {where}")
        print(f"⚠️ Duplicate image - already exists in: {where}")
        return

    img.save(self.root_dir / folder / filename)
    logger.debug(f"Added image {img_hash} to {folder}")

    if label:
        # New training data makes previously learned thresholds stale;
        # decide() retrains whenever self.thresholds is empty.
        self.thresholds = {}
|
161
|
+
|
162
|
+
def teach(self, labels: Optional[List[str]] = None, review: bool = False) -> None:
    """
    Interactive teaching interface using IPython widgets.

    Shows stored images one at a time. The user can file each image under
    one of two labels (moving it into that label's folder), skip it, step
    back, or move it to ``_removed``. After the last image, the judge is
    retrained if anything was labeled.

    Args:
        labels: Labels to use for teaching. Defaults to self.labels. Must be
            two distinct labels of this judge. The first is bound to the
            left button/arrow key, the second to the right.
        review: If True, review already labeled images (shuffled) for
            re-classification instead of the unlabeled pool.

    Raises:
        JudgeError: If IPython/ipywidgets are unavailable or ``labels`` is
            invalid.
    """
    try:
        import ipywidgets as widgets
        from IPython.display import clear_output, display
    except ImportError as e:
        raise JudgeError(
            "Teaching requires IPython and ipywidgets. Use 'pip install ipywidgets'"
        ) from e

    labels = list(labels) if labels else list(self.labels)
    # Files are moved into a folder named after the chosen label, and those
    # folders only exist for this judge's own labels.
    if len(labels) != 2 or labels[0] == labels[1] or any(lb not in self.labels for lb in labels):
        raise JudgeError(
            f"Teaching labels must be two distinct labels from {self.labels}, got {labels}"
        )

    # Collect the images to show.
    if review:
        files_to_review = [
            (img_path, label)
            for label in self.labels
            for img_path in sorted((self.root_dir / label).glob("*.png"))
        ]
        if not files_to_review:
            print("No labeled images to review")
            return

        import random

        random.shuffle(files_to_review)
        review_files = [path for path, _ in files_to_review]
        original_labels = {str(path): label for path, label in files_to_review}
    else:
        review_files = sorted((self.root_dir / "unlabeled").glob("*.png"))
        original_labels = {}

    if not review_files:
        print("No unlabeled images to teach")
        return

    self._teaching_state = {
        "current_index": 0,
        "labeled_count": 0,
        "removed_count": 0,
        "files": review_files,
        "labels": labels,
        "review_mode": review,
        "original_labels": original_labels,
        # Set once every image has been visited; all handlers then become
        # no-ops so stray key presses cannot re-run completion/retraining.
        "finished": False,
    }

    # Widgets
    image_widget = widgets.Image()
    status_label = widgets.Label()
    button_layout = widgets.Layout(width="auto", margin="5px")
    btn_prev = widgets.Button(description="↑ Previous", layout=button_layout)
    btn_class1 = widgets.Button(
        description=f"← {labels[0]}", layout=button_layout, button_style="primary"
    )
    btn_class2 = widgets.Button(
        description=f"→ {labels[1]}", layout=button_layout, button_style="success"
    )
    btn_skip = widgets.Button(description="↓ Skip", layout=button_layout)
    btn_remove = widgets.Button(
        description="✗ Remove", layout=button_layout, button_style="danger"
    )
    all_buttons = (btn_prev, btn_class1, btn_class2, btn_skip, btn_remove)
    button_box = widgets.HBox(list(all_buttons))
    info_label = widgets.Label(
        value=f"Keys: ↑ prev | ← {labels[0]} | → {labels[1]} | ↓ skip | Delete remove"
    )

    def update_display():
        """Show the current image, or finish (and retrain) when past the end."""
        state = self._teaching_state
        idx = state["current_index"]
        if 0 <= idx < len(state["files"]):
            img_path = state["files"][idx]
            image_widget.value = img_path.read_bytes()

            status_text = f"Image {idx + 1} of {len(state['files'])}"
            if state["review_mode"]:
                current_label = state["original_labels"].get(str(img_path), "unknown")
                status_text += f" (Currently: {current_label})"
            status_text += f" | Labeled: {state['labeled_count']}"
            if state["removed_count"] > 0:
                status_text += f" | Removed: {state['removed_count']}"
            status_label.value = status_text
            btn_prev.disabled = idx == 0
            return

        state["finished"] = True
        status_label.value = "Teaching complete!"
        # Hide the image widget instead of showing a broken image.
        image_widget.layout.display = "none"
        for button in all_buttons:
            button.disabled = True

        if state["labeled_count"] > 0 or state["removed_count"] > 0:
            clear_output(wait=True)
            print("Teaching complete!")
            print(f"Labeled: {state['labeled_count']} images")
            if state["removed_count"] > 0:
                print(f"Removed: {state['removed_count']} images")

            if state["labeled_count"] > 0:
                print("\nRetraining with new examples...")
                self._retrain()
                print("✓ Training complete! Judge is ready to use.")
        else:
            print("No changes made.")

    def move_current(folder, counter):
        """Move the current image into ``folder`` and advance."""
        state = self._teaching_state
        idx = state["current_index"]
        if state["finished"] or idx >= len(state["files"]):
            return
        current_file = state["files"][idx]
        destination = self.root_dir / folder / current_file.name
        shutil.move(str(current_file), str(destination))
        # Remember the new location so "Previous" can still display the image
        # (its old path no longer exists after the move).
        state["files"][idx] = destination
        state["original_labels"][str(destination)] = folder
        state[counter] += 1
        state["current_index"] += 1
        update_display()

    def on_prev(_button):
        state = self._teaching_state
        if state["finished"] or state["current_index"] == 0:
            return
        state["current_index"] -= 1
        update_display()

    def on_class1(_button):
        move_current(labels[0], "labeled_count")

    def on_class2(_button):
        move_current(labels[1], "labeled_count")

    def on_skip(_button):
        state = self._teaching_state
        if state["finished"]:
            return
        state["current_index"] += 1
        update_display()

    def on_remove(_button):
        move_current("_removed", "removed_count")

    btn_prev.on_click(on_prev)
    btn_class1.on_click(on_class1)
    btn_class2.on_click(on_class2)
    btn_skip.on_click(on_skip)
    btn_remove.on_click(on_remove)

    key_actions = {
        "ArrowUp": on_prev,
        "ArrowLeft": on_class1,
        "ArrowRight": on_class2,
        "ArrowDown": on_skip,
        "Delete": on_remove,
        "Backspace": on_remove,
    }

    def on_key(event):
        """Dispatch keyboard events to the matching button handler."""
        if event.get("type") != "keydown":
            return
        action = key_actions.get(event.get("key"))
        if action is not None:
            action(None)

    ui = widgets.VBox([status_label, image_widget, button_box, info_label])
    display(ui)
    update_display()

    # Keyboard shortcuts are optional (best effort).
    try:
        from ipyevents import Event
    except ImportError:
        print("Note: Install ipyevents for keyboard shortcuts: pip install ipyevents")
        return
    try:
        # Attach to the visible UI rather than an empty widget so there is
        # something on screen to hover/focus when pressing keys.
        event_handler = Event(source=ui, watched_events=["keydown"])
        event_handler.on_dom_event(on_key)
    except Exception as e:
        logger.warning(f"Keyboard shortcuts unavailable: {e}")
|
388
|
+
|
389
|
+
def decide(self, regions: Union["Region", List["Region"]]) -> Union[Decision, List[Decision]]:
    """
    Classify one or more regions.

    Each learned metric casts a soft vote:
    1. Take the signed distance of the region's metric value from the
       metric's threshold.
    2. Scale it by the std of the class on whose side the value falls.
    3. Clip it to [-6, 6] and pass it through a sigmoid.
    4. Weight it by the metric's stored "accuracy" (Youden's J).

    The normalized votes are then shifted in logit space from the empirical
    training-class ratio towards ``target_prior`` when both classes have
    examples.

    Args:
        regions: A single region, or a list/tuple of regions to classify

    Returns:
        A Decision for a single region, or a list of Decisions (input order)
        for a list/tuple. ``score`` is the winning label's normalized vote.

    Raises:
        JudgeError: If any class has no stored example yet
    """
    for label in self.labels:
        if not any((self.root_dir / label).glob("*.png")):
            raise JudgeError(f"Need at least one example of class '{label}' before deciding")

    # Learn thresholds on first use if none are loaded.
    if not self.thresholds:
        self._retrain()

    # Lists and tuples are batches; anything else is a single region.
    single_input = not isinstance(regions, (list, tuple))
    batch = [regions] if single_input else list(regions)

    label1, label2 = self.labels

    def _logit(p, eps=1e-6):
        p = max(eps, min(1 - eps, p))
        return np.log(p / (1 - p))

    def _sigmoid(x):
        # Numerically stable logistic function.
        if x >= 0:
            return 1.0 / (1.0 + np.exp(-x))
        z = np.exp(x)
        return z / (1.0 + z)

    # The prior correction depends only on the stored training set, so it is
    # computed once for the whole batch rather than per region.
    counts = self._get_training_counts()
    n1 = counts.get(label1, 0)
    n2 = counts.get(label2, 0)
    bias = None
    if n1 > 0 and n2 > 0:  # only correct when both classes have examples
        total = n1 + n2
        bias = (
            _logit(self.target_prior) - _logit(n1 / total),
            _logit(1.0 - self.target_prior) - _logit(n2 / total),
        )

    results = []
    for region in batch:
        metrics = self._extract_metrics(region)

        votes = {label1: 0.0, label2: 0.0}
        total_weight = 0.0
        for metric_name, value in metrics.items():
            if metric_name not in self.thresholds:
                continue
            metric_info = self.thresholds[metric_name]
            weight = metric_info["accuracy"]  # stored value is Youden's J
            threshold1, direction1 = metric_info["thresholds"][label1]

            # Per-class spreads turn the raw distance into a comparable margin.
            stats = self.metrics_info.get(metric_name, {})
            s1 = stats.get(f"std_{label1}", 0.0)
            s2 = stats.get(f"std_{label2}", 0.0)
            scale1 = s1 if s1 > 1e-6 else 1.0
            scale2 = s2 if s2 > 1e-6 else 1.0

            # Signed margin: positive favors label1, negative favors label2.
            if direction1 == "higher":
                margin = (value - threshold1) / (scale1 if value >= threshold1 else scale2)
            else:
                margin = (threshold1 - value) / (scale1 if value <= threshold1 else scale2)
            # Clip so no single metric dominates.
            margin = np.clip(margin, -6, 6)

            p1 = 1.0 / (1.0 + np.exp(-margin))
            votes[label1] += weight * p1
            votes[label2] += weight * (1.0 - p1)
            total_weight += weight

        if total_weight > 0:
            votes = {label: vote / total_weight for label, vote in votes.items()}
        else:
            # No usable metric: neutral votes so the prior still applies.
            votes = {label1: 0.5, label2: 0.5}

        if bias is not None:
            v1 = _sigmoid(_logit(votes[label1]) + bias[0])
            v2 = _sigmoid(_logit(votes[label2]) + bias[1])
            s = v1 + v2
            votes = {label1: v1 / s, label2: v2 / s}

        best_label, best_score = max(votes.items(), key=lambda item: item[1])
        results.append(Decision(label=best_label, score=best_score))

    return results[0] if single_input else results
|
515
|
+
|
516
|
+
def pick(
    self, target_label: str, regions: List["Region"], labels: Optional[List[str]] = None
) -> PickResult:
    """
    Pick which region best matches the target label.

    All regions are classified. Among those classified as ``target_label``,
    the one with the highest score wins (the first one on ties).

    Args:
        target_label: The class label to look for
        regions: Regions to choose from (list, tuple, or any iterable)
        labels: Optional human-friendly labels for each region, aligned
            with ``regions`` by position

    Returns:
        PickResult with the winning region, its index, its human-friendly
        label (None if ``labels`` was not given or is too short), and score

    Raises:
        JudgeError: If target_label is not in allowed labels, or if no
            region is classified as target_label (including empty input)
    """
    if target_label not in self.labels:
        raise JudgeError(f"Target label '{target_label}' not in allowed labels: {self.labels}")

    # Materialise once so tuples, generators and collection objects are
    # classified as a batch and can be indexed afterwards.
    candidates = list(regions)
    decisions = self.decide(candidates)

    best_index, best_score = -1, -1.0
    for i, decision in enumerate(decisions):
        if decision.label == target_label and decision.score > best_score:
            best_index, best_score = i, decision.score

    if best_index == -1:
        raise JudgeError(f"No region classified as '{target_label}'")

    label = labels[best_index] if labels and best_index < len(labels) else None
    return PickResult(
        region=candidates[best_index], index=best_index, label=label, score=best_score
    )
|
557
|
+
|
558
|
+
def count(self, target_label: str, regions: List["Region"]) -> int:
    """
    Count how many regions match the target label.

    Args:
        target_label: The class label to count
        regions: Regions to check (list, tuple, or any iterable)

    Returns:
        Number of regions classified as target_label

    Raises:
        JudgeError: If target_label is not one of this judge's labels
            (otherwise a typo would silently count 0)
    """
    if target_label not in self.labels:
        raise JudgeError(f"Target label '{target_label}' not in allowed labels: {self.labels}")

    # Materialise so any iterable is classified as a batch.
    decisions = self.decide(list(regions))
    return sum(1 for d in decisions if d.label == target_label)
|
571
|
+
|
572
|
+
def info(self) -> None:
    """
    Print a summary of this Judge.

    The summary includes:
    - name and labels;
    - the target prior, if it is not neutral (0.5);
    - per-class example counts, plus the unlabeled count when non-zero;
    - the class-imbalance ratio, when both classes have examples in unequal
      numbers.
    """
    print(f"Judge: {self.name}")
    print(f"Labels: {self.labels}")
    if self.target_prior != 0.5:
        favored = self.labels[0] if self.target_prior > 0.5 else self.labels[1]
        print(f"Target prior: {self.target_prior:.2f} (favors '{favored}')")

    counts = self._get_training_counts()
    print("\nTraining examples:")
    for label in self.labels:
        print(f" {label}: {counts.get(label, 0)}")

    n_unlabeled = counts.get("unlabeled", 0)
    if n_unlabeled > 0:
        print(f" unlabeled: {n_unlabeled}")

    per_label = [counts.get(label, 0) for label in self.labels]
    if not all(n > 0 for n in per_label):
        return
    high, low = max(per_label), min(per_label)
    if high == low:
        return

    # The last label carrying each extreme count is reported.
    majority = next(lb for lb in reversed(self.labels) if counts.get(lb, 0) == high)
    minority = next(lb for lb in reversed(self.labels) if counts.get(lb, 0) == low)
    ratio = high / low
    print(f"\nClass imbalance: {majority}:{minority} = {high}:{low} ({ratio:.1f}:1)")
    print(" Using Youden's J weights with soft voting and prior correction")
|
614
|
+
|
615
|
+
def inspect(self, preview: bool = True) -> None:
    """
    Inspect all stored examples, showing their true labels and predicted labels/scores.
    Useful for debugging classification issues.

    Per-class and overall accuracy are computed over *every* stored example of
    each label; only the first 20 (HTML) or 10 (text) examples per section are
    listed individually.

    Args:
        preview: If True (default), display images inline in HTML tables (requires IPython/Jupyter).
            If False, use text-only output. If IPython cannot be imported, the
            complete text report (summary and thresholds included) is printed.
    """
    if not self.thresholds:
        print("No trained model yet. Add examples and the model will auto-train.")
        return

    if preview:
        # Resolve the optional dependency before choosing a renderer so that a
        # fallback still produces the full text report.
        try:
            from IPython.display import HTML, display
        except ImportError:
            print("Preview mode requires IPython/Jupyter. Falling back to text mode.")
            preview = False

    if preview:
        display(HTML(self._inspect_html()))
    else:
        self._inspect_text()

@staticmethod
def _inspect_load(img_path):
    """Load one stored example PNG and wrap it in a minimal stand-in region.

    Returns ``(image, region)``; ``region.render(crop=True)`` returns the image,
    which is the call ``_extract_metrics`` makes on its argument (``decide`` is
    presumably the same — confirm).
    """
    img = Image.open(img_path)
    img.load()  # read pixel data now rather than lazily on first use
    mock_region = type("MockRegion", (), {"render": lambda self, crop=True: img})()
    return img, mock_region

def _inspect_metric_strs(self, metrics, spec):
    """Format up to three metrics that have learned thresholds as ``name=value``.

    Metrics are taken in name order; *spec* is the numeric format spec (e.g. ".2f").
    """
    return [
        f"{name}={value:{spec}}"
        for name, value in sorted(metrics.items())
        if name in self.thresholds
    ][:3]

def _inspect_text(self) -> None:
    """Print the text inspection report used by ``inspect(preview=False)``.

    Prints ``self.info()``, the learned thresholds with per-label distribution
    statistics, predictions for up to 10 examples per labeled class (accuracy
    covers all of them), and predictions for up to 10 unlabeled examples.
    """
    shown_limit = 10

    self.info()
    print("-" * 80)

    print("\nThresholds learned:")
    for metric, info in self.thresholds.items():
        weight = info["accuracy"]  # stored weight is Youden's J
        selection_acc = info.get("selection_accuracy", info["accuracy"])  # fallback for old models
        print(f" {metric}: weight={weight:.3f} (selection_accuracy={selection_acc:.3f})")
        for label, (threshold, direction) in info["thresholds"].items():
            print(f" {label}: {direction} than {threshold:.3f}")

        if metric in self.metrics_info:
            metric_stats = self.metrics_info[metric]
            for label in self.labels:
                mean_key, std_key = f"mean_{label}", f"std_{label}"
                if mean_key in metric_stats:
                    print(
                        f" {label} distribution: mean={metric_stats[mean_key]:.3f}, std={metric_stats[std_key]:.3f}"
                    )

    print("\nPredictions on training data:")
    print("-" * 80)

    all_correct = 0
    all_total = 0
    for true_label in self.labels:
        examples = sorted((self.root_dir / true_label).glob("*.png"))
        if not examples:
            continue

        print(f"\n{true_label.upper()} examples ({len(examples)} total):")
        correct = 0
        for position, img_path in enumerate(examples):
            _, mock_region = self._inspect_load(img_path)
            decision = self.decide(mock_region)
            is_correct = decision.label == true_label
            if is_correct:
                correct += 1
            if position >= shown_limit:
                continue  # scored for accuracy, but not listed

            metrics = self._extract_metrics(mock_region)
            status = "✓" if is_correct else "✗"
            print(f" {status} {img_path.name}: predicted={decision.label} (score={decision.score:.3f})")
            metric_strs = self._inspect_metric_strs(metrics, ".2f")
            if metric_strs:
                print(f" Metrics: {', '.join(metric_strs)}")

        accuracy = correct / len(examples)
        all_correct += correct
        all_total += len(examples)

        if len(examples) > shown_limit:
            print(f" ... and {len(examples) - shown_limit} more")
        print(f" Accuracy for {true_label}: {accuracy:.1%} ({correct}/{len(examples)})")

    if all_total > 0:
        overall_accuracy = all_correct / all_total
        print(f"\nOverall accuracy: {overall_accuracy:.1%} ({all_correct}/{all_total})")

    unlabeled_examples = sorted((self.root_dir / "unlabeled").glob("*.png"))
    if unlabeled_examples:
        print(f"\nUNLABELED examples ({len(unlabeled_examples)} total) - predictions:")
        for img_path in unlabeled_examples[:shown_limit]:
            _, mock_region = self._inspect_load(img_path)
            decision = self.decide(mock_region)
            metrics = self._extract_metrics(mock_region)
            print(f" {img_path.name}: predicted={decision.label} (score={decision.score:.3f})")
            metric_strs = self._inspect_metric_strs(metrics, ".2f")
            if metric_strs:
                print(f" Metrics: {', '.join(metric_strs)}")

        if len(unlabeled_examples) > shown_limit:
            print(f" ... and {len(unlabeled_examples) - shown_limit} more")

def _inspect_html(self) -> str:
    """Build the HTML inspection report used by ``inspect(preview=True)``.

    Sections: configuration, per-class training counts (majority row
    highlighted when the ratio exceeds 1.5:1), learned thresholds, predictions
    for each labeled class, and predictions for unlabeled examples. At most 20
    rows are rendered per prediction section; accuracy covers every example.
    User-controlled text (name, labels, predicted labels) is HTML-escaped.
    """
    import base64
    import io
    from html import escape

    shown_limit = 20

    def png_base64(img) -> str:
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode()

    parts = [
        "<style>",
        "table { border-collapse: collapse; margin: 20px 0; }",
        "th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }",
        "th { background-color: #f2f2f2; font-weight: bold; }",
        "img { max-width: 60px; max-height: 60px; }",
        ".correct { color: green; }",
        ".incorrect { color: red; }",
        ".metrics { font-size: 0.9em; color: #666; }",
        "h3 { margin-top: 30px; }",
        ".imbalance-warning { background-color: #fff3cd; color: #856404; }",
        "</style>",
    ]

    # Configuration table
    prior_cell = f"{self.target_prior:.2f}"
    if self.target_prior != 0.5:
        favored = self.labels[0] if self.target_prior > 0.5 else self.labels[1]
        prior_cell += f" (favors '{escape(favored)}')"
    parts += [
        "<h3>Judge Configuration</h3>",
        "<table>",
        "<tr><th>Property</th><th>Value</th></tr>",
        f"<tr><td>Name</td><td>{escape(str(self.name))}</td></tr>",
        f"<tr><td>Labels</td><td>{escape(', '.join(self.labels))}</td></tr>",
        f"<tr><td>Target Prior</td><td>{prior_cell}</td></tr>",
        "</table>",
    ]

    # Training counts table
    counts = self._get_training_counts()
    labeled_counts = [counts.get(label, 0) for label in self.labels]
    ratio = 1.0
    if labeled_counts and all(c > 0 for c in labeled_counts):
        ratio = max(labeled_counts) / min(labeled_counts)
    is_imbalanced = ratio > 1.5
    majority = max(labeled_counts, default=0)

    parts += ["<h3>Training Examples</h3>", "<table>", "<tr><th>Class</th><th>Count</th></tr>"]
    for label, count in zip(self.labels, labeled_counts):
        row_class = ' class="imbalance-warning"' if is_imbalanced and count == majority else ""
        parts.append(f"<tr{row_class}><td>{escape(label)}</td><td>{count}</td></tr>")
    if counts.get("unlabeled", 0) > 0:
        parts.append(f"<tr><td>unlabeled</td><td>{counts['unlabeled']}</td></tr>")
    parts.append("</table>")

    if is_imbalanced:
        parts.append(
            f"<p><em>Class imbalance detected ({ratio:.1f}:1). Using Youden's J weights with prior correction.</em></p>"
        )

    # Thresholds table
    parts += [
        "<h3>Learned Thresholds</h3>",
        "<table>",
        "<tr><th>Metric</th><th>Weight (Youden's J)</th><th>Selection Accuracy</th><th>Threshold Details</th></tr>",
    ]
    for metric, info in self.thresholds.items():
        weight = info["accuracy"]  # stored weight is Youden's J
        selection_acc = info.get("selection_accuracy", weight)  # absent in older configs
        details = [
            f"<br>{escape(label)}: {direction} than {threshold:.3f}"
            for label, (threshold, direction) in info["thresholds"].items()
        ]
        if metric in self.metrics_info:
            metric_stats = self.metrics_info[metric]
            details.append("<br><em>Distributions:</em>")
            for label in self.labels:
                mean_key, std_key = f"mean_{label}", f"std_{label}"
                if mean_key in metric_stats:
                    details.append(
                        f"<br> {escape(label)}: μ={metric_stats[mean_key]:.1f}, σ={metric_stats[std_key]:.1f}"
                    )
        details_html = "".join(details)
        parts.append(
            f"<tr><td>{escape(metric)}</td><td>{weight:.3f}</td>"
            f"<td>{selection_acc:.3f}</td><td>{details_html}</td></tr>"
        )
    parts.append("</table>")

    # Predictions for each labeled class
    all_correct = 0
    all_total = 0
    for true_label in self.labels:
        examples = sorted((self.root_dir / true_label).glob("*.png"))
        if not examples:
            continue

        parts += [
            f"<h3>Predictions: {escape(true_label.upper())} ({len(examples)} total)</h3>",
            "<table>",
            "<tr><th>Image</th><th>Status</th><th>Predicted</th><th>Score</th><th>Key Metrics</th></tr>",
        ]
        correct = 0
        for position, img_path in enumerate(examples):
            img, mock_region = self._inspect_load(img_path)
            decision = self.decide(mock_region)
            is_correct = decision.label == true_label
            if is_correct:
                correct += 1
            if position >= shown_limit:
                continue  # scored for accuracy, but not rendered

            metrics_html = "<br>".join(
                self._inspect_metric_strs(self._extract_metrics(mock_region), ".1f")
            )
            status_class = "correct" if is_correct else "incorrect"
            status_symbol = "✓" if is_correct else "✗"
            parts.append(
                "<tr>"
                f'<td><img src="data:image/png;base64,{png_base64(img)}" /></td>'
                f'<td class="{status_class}">{status_symbol}</td>'
                f"<td>{escape(str(decision.label))}</td>"
                f"<td>{decision.score:.3f}</td>"
                f'<td class="metrics">{metrics_html}</td>'
                "</tr>"
            )
        parts.append("</table>")

        accuracy = correct / len(examples)
        all_correct += correct
        all_total += len(examples)

        if len(examples) > shown_limit:
            parts.append(f"<p><em>... and {len(examples) - shown_limit} more</em></p>")
        parts.append(
            f"<p>Accuracy for {escape(true_label)}: <strong>{accuracy:.1%}</strong> ({correct}/{len(examples)})</p>"
        )

    if all_total > 0:
        overall_accuracy = all_correct / all_total
        parts.append(f"<h3>Overall accuracy: {overall_accuracy:.1%} ({all_correct}/{all_total})</h3>")

    # Predictions for unlabeled examples
    unlabeled_examples = sorted((self.root_dir / "unlabeled").glob("*.png"))
    if unlabeled_examples:
        parts += [
            f"<h3>Predictions: UNLABELED ({len(unlabeled_examples)} total)</h3>",
            "<table>",
            "<tr><th>Image</th><th>Predicted</th><th>Score</th><th>Key Metrics</th></tr>",
        ]
        for img_path in unlabeled_examples[:shown_limit]:
            img, mock_region = self._inspect_load(img_path)
            decision = self.decide(mock_region)
            metrics_html = "<br>".join(
                self._inspect_metric_strs(self._extract_metrics(mock_region), ".1f")
            )
            parts.append(
                "<tr>"
                f'<td><img src="data:image/png;base64,{png_base64(img)}" /></td>'
                f"<td>{escape(str(decision.label))}</td>"
                f"<td>{decision.score:.3f}</td>"
                f'<td class="metrics">{metrics_html}</td>'
                "</tr>"
            )
        parts.append("</table>")

        if len(unlabeled_examples) > shown_limit:
            parts.append(f"<p><em>... and {len(unlabeled_examples) - shown_limit} more</em></p>")

    return "".join(parts)
|
986
|
+
|
987
|
+
def lookup(self, region) -> Optional[Tuple[str, Image.Image]]:
    """
    Look up a region and return its hash and stored image if found in training data.

    The region is rendered, converted to RGB, and hashed (first 12 hex digits of
    the MD5 of its raw pixel bytes). ``<hash>.png`` is then searched for in each
    configured label directory, then ``unlabeled``, then ``_removed``.

    Args:
        region: Region to look up (anything exposing ``render(crop=True)``).

    Returns:
        Tuple of (hash, image) if found, None if not found or if rendering /
        hashing fails (the failure is logged).
    """
    try:
        img = region.render(crop=True)
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        if img.mode != "RGB":
            img = img.convert("RGB")
        img_hash = hashlib.md5(np.array(img).tobytes()).hexdigest()[:12]

        # Search every configured label directory (not only the default
        # "checked"/"unchecked" names), then the bookkeeping directories.
        for subdir in [*self.labels, "unlabeled", "_removed"]:
            img_path = self.root_dir / subdir / f"{img_hash}.png"
            if img_path.exists():
                stored_img = Image.open(img_path)
                stored_img.load()  # read pixels now instead of leaving the file open lazily
                logger.debug(f"Found region in '{subdir}' with hash {img_hash}")
                return (img_hash, stored_img)

        logger.debug(f"Region not found in training data (hash: {img_hash})")
        return None

    except Exception as e:
        logger.error(f"Failed to lookup region: {e}")
        return None
|
1026
|
+
|
1027
|
+
def show(self, max_per_class: int = 10, size: Tuple[int, int] = (100, 100)) -> None:
    """
    Display a grid showing examples from each category.

    One row per label directory (plus ``unlabeled`` when it has images), each
    headed by the category name and how many of its images are shown. A banner
    is added when the labeled classes differ by more than 3:1.

    Args:
        max_per_class: Maximum number of examples to show per class
        size: Bounding box (width, height) in pixels for each image. Images are
            fitted into the box with their aspect ratio preserved.
    """
    import io

    try:
        import ipywidgets as widgets
        from IPython.display import display
        from PIL import Image as PILImage
    except ImportError:
        print("Show requires IPython and ipywidgets")
        return

    # Collect (sorted, truncated) images from each category
    categories = {}
    total_counts = {}
    for category in [*self.labels, "unlabeled"]:
        all_images = sorted((self.root_dir / category).glob("*.png"))
        total_counts[category] = len(all_images)
        selected = all_images[:max_per_class]
        if selected:
            categories[category] = selected

    if not categories:
        print("No images to show")
        return

    rows = []

    # Check for class imbalance
    labeled_counts = {k: v for k, v in total_counts.items() if k != "unlabeled"}
    if len(labeled_counts) >= 2:
        max_count = max(labeled_counts.values())
        min_count = min(labeled_counts.values())
        if min_count > 0 and max_count / min_count > 3:
            rows.append(
                widgets.HTML(
                    '<div style="background: #fff3cd; padding: 10px; margin: 10px 0; border: 1px solid #ffeeba; border-radius: 4px;">'
                    f"<strong>⚠️ Class imbalance detected:</strong> {labeled_counts}<br>"
                    "Consider adding more examples of the minority class for better accuracy."
                    "</div>"
                )
            )

    box_w, box_h = size
    for category, image_paths in categories.items():
        shown = len(image_paths)
        total = total_counts[category]
        count_text = f"{shown} of {total} shown" if shown < total else f"{total} total"
        header = widgets.HTML(f"<h3>{category} ({count_text})</h3>")

        image_widgets = []
        for img_path in image_paths:
            with PILImage.open(img_path) as img:
                img.thumbnail(size, PILImage.Resampling.LANCZOS)
                img_bytes = io.BytesIO()
                img.save(img_bytes, format="PNG")
                thumb_w, thumb_h = img.size

            # Fit the image into the box without distorting non-square crops;
            # small crops are still enlarged to fill the box.
            scale = min(box_w / thumb_w, box_h / thumb_h)
            image_widgets.append(
                widgets.Image(
                    value=img_bytes.getvalue(),
                    width=max(1, round(thumb_w * scale)),
                    height=max(1, round(thumb_h * scale)),
                )
            )

        rows.append(widgets.VBox([header, widgets.HBox(image_widgets)]))

    # Display all categories
    display(widgets.VBox(rows))
|
1119
|
+
|
1120
|
+
def forget(self, region: Optional["Region"] = None, delete: bool = False) -> None:
    """
    Clear training data, delete all files, or move a specific region to unlabeled.

    Modes:
      * ``region`` given: the stored example whose pixel hash matches the
        region is moved from its label directory (or ``_removed``) into
        ``unlabeled``. If it came from a label directory the thresholds are
        retrained so they no longer reflect it.
      * ``delete=True``: the judge directory is removed, thresholds are reset
        in memory, and an empty directory layout is recreated.
      * otherwise: every labeled example is moved to ``unlabeled``, thresholds
        are reset and the saved config file is removed.

    Args:
        region: If provided, move this specific region to unlabeled
        delete: If True, permanently delete all files
    """
    unlabeled_dir = self.root_dir / "unlabeled"

    # Handle specific region case
    if region is not None:
        try:
            img = region.render(crop=True)
            if not isinstance(img, Image.Image):
                img = Image.fromarray(img)
            if img.mode != "RGB":
                img = img.convert("RGB")
            img_hash = hashlib.md5(np.array(img).tobytes()).hexdigest()[:12]
        except Exception as e:
            logger.error(f"Failed to hash region: {e}")
            return

        for label in [*self.labels, "_removed"]:
            source_path = self.root_dir / label / f"{img_hash}.png"
            if source_path.exists():
                unlabeled_dir.mkdir(parents=True, exist_ok=True)
                shutil.move(str(source_path), str(unlabeled_dir / f"{img_hash}.png"))
                print(f"Moved region from '{label}' to 'unlabeled'")
                if label != "_removed":
                    # The example no longer belongs to a class; refresh thresholds.
                    self._retrain()
                return

        print("Region not found in training data")
        return

    if delete:
        # Delete entire directory
        if self.root_dir.exists():
            shutil.rmtree(self.root_dir)
            print(f"Deleted all data for judge '{self.name}'")
        else:
            print(f"No data found for judge '{self.name}'")

        # Reset internal state
        self.thresholds = {}
        self.metrics_info = {}

        # Recreate directory structure
        self.root_dir.mkdir(parents=True, exist_ok=True)
        for subdir in [*self.labels, "unlabeled", "_removed"]:
            (self.root_dir / subdir).mkdir(exist_ok=True)
        return

    # Just clear training (move everything to unlabeled). Materialize the list
    # first so the directories are not modified while being listed.
    to_move = [
        img_path
        for label in self.labels
        for img_path in sorted((self.root_dir / label).glob("*.png"))
    ]
    if to_move:
        unlabeled_dir.mkdir(parents=True, exist_ok=True)
    for img_path in to_move:
        shutil.move(str(img_path), str(unlabeled_dir / img_path.name))

    # Clear thresholds
    self.thresholds = {}
    self.metrics_info = {}

    # Remove saved config
    if self.config_path.exists():
        self.config_path.unlink()

    print(f"Moved {len(to_move)} labeled images back to unlabeled.")
    print("Training data cleared. Judge is now untrained.")
|
1201
|
+
|
1202
|
+
def save(self, path: Optional[str] = None) -> None:
    """
    Save the judge configuration (auto-retrains first).

    The config written is a JSON object with the judge name, labels, target
    prior, learned thresholds, metric statistics and per-class example counts.

    Args:
        path: Optional path to save to. Defaults to ``self.config_path``.
            Missing parent directories are created.
    """
    # Retrain with current examples
    self._retrain()

    save_path = Path(path) if path else Path(self.config_path)

    config = {
        "name": self.name,
        "labels": self.labels,
        "target_prior": self.target_prior,
        "thresholds": self.thresholds,
        "metrics_info": self.metrics_info,
        "training_counts": self._get_training_counts(),
    }

    # A custom destination may point into a directory that does not exist yet.
    save_path.parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)

    logger.info(f"Saved judge to {save_path}")
|
1228
|
+
|
1229
|
+
@classmethod
def load(cls, path: str) -> "Judge":
    """
    Rebuild a judge from a saved ``judge.json``.

    Args:
        path: Either the judge directory (``judge.json`` is read from inside it,
            its name becomes the judge name and its parent the base directory)
            or the JSON file itself (the name then comes from the file's
            ``"name"`` entry and the base directory is two levels up).

    Returns:
        A new instance constructed with the saved labels and target prior
        (0.5 when the config predates that field), with the saved thresholds
        and metric statistics restored.
    """
    source = Path(path)

    if source.is_dir():
        config_file = source / "judge.json"
        root_parent = source.parent
        judge_name = source.name
    else:
        config_file = source
        containing = source.parent
        root_parent = containing.parent if containing.name != "." else containing
        judge_name = None  # taken from the stored config below

    with open(config_file, "r") as handle:
        saved = json.load(handle)

    judge = cls(
        saved["name"] if judge_name is None else judge_name,
        labels=saved["labels"],
        base_dir=root_parent,
        target_prior=saved.get("target_prior", 0.5),
    )
    judge.thresholds = saved["thresholds"]
    judge.metrics_info = saved.get("metrics_info", {})
    return judge
|
1271
|
+
|
1272
|
+
# Private methods
|
1273
|
+
|
1274
|
+
def _extract_metrics(self, region) -> Dict[str, float]:
    """Extract grayscale image metrics from a region.

    The region is rendered with ``render(crop=True)`` (a non-PIL result is
    converted with ``Image.fromarray``) and analysed in grayscale. Returned keys:
    ``center_darkness``, ``ink_density``, ``dark_pixel_ratio`` (fraction of
    pixels < 200), ``std_dev``, ``edge_center_ratio`` and ``diagonal_density``.

    Raises:
        JudgeError: If rendering or analysis fails (original error chained).
    """
    try:
        img = region.render(crop=True)
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)

        # Convert to grayscale for analysis
        gray = np.array(img.convert("L"))

        metrics = {}

        # 1. Center darkness
        h, w = gray.shape
        cy, cx = h // 2, w // 2
        center_size = min(5, h // 4, w // 4)  # Adaptive center size
        center = gray[
            max(0, cy - center_size) : min(h, cy + center_size + 1),
            max(0, cx - center_size) : min(w, cx + center_size + 1),
        ]
        metrics["center_darkness"] = 255 - np.mean(center)

        # 2. Overall darkness (ink density)
        metrics["ink_density"] = 255 - np.mean(gray)

        # 3. Dark pixel ratio
        metrics["dark_pixel_ratio"] = np.sum(gray < 200) / gray.size

        # 4. Standard deviation (complexity)
        metrics["std_dev"] = np.std(gray)

        # 5. Edge vs center ratio
        edge_size = max(2, min(h // 10, w // 10))
        edge_mask = np.zeros_like(gray, dtype=bool)
        edge_mask[:edge_size, :] = True
        edge_mask[-edge_size:, :] = True
        edge_mask[:, :edge_size] = True
        edge_mask[:, -edge_size:] = True

        edge_mean = np.mean(gray[edge_mask]) if np.any(edge_mask) else 255
        center_mean = np.mean(center)
        metrics["edge_center_ratio"] = edge_mean / (center_mean + 1)

        # 6. Diagonal density (for X patterns): both diagonals of the
        # top-left min(h, w) square, built with vectorized indexing.
        if h > 10 and w > 10:
            idx = np.arange(min(h, w))
            diag_mask = np.zeros_like(gray, dtype=bool)
            diag_mask[idx, idx] = True
            diag_mask[idx, w - 1 - idx] = True
            metrics["diagonal_density"] = 255 - np.mean(gray[diag_mask])
        else:
            metrics["diagonal_density"] = metrics["ink_density"]

        return metrics

    except Exception as e:
        raise JudgeError(f"Failed to extract metrics: {e}") from e
|
1332
|
+
|
1333
|
+
def _retrain(self) -> None:
    """Retrain thresholds from current examples.

    Metrics are extracted for every ``*.png`` in each label directory (in
    sorted order). If any label has no examples, a warning is logged and the
    existing thresholds are left untouched. Only two-label judges are trained;
    any other label count clears the thresholds and logs a warning.

    For each metric (in name order, so results are reproducible) the best
    balanced-accuracy split is found via ``_best_binary_split``. Metrics whose
    selection accuracy exceeds the cutoff (0.55 when classes are imbalanced by
    more than 1.5:1, else 0.60) are kept, or the top three if none pass. Youden's
    J is stored under ``"accuracy"`` as the metric weight.
    """
    # Collect all examples
    examples = {label: [] for label in self.labels}
    for label in self.labels:
        for img_path in sorted((self.root_dir / label).glob("*.png")):
            with Image.open(img_path) as img:
                # Minimal stand-in exposing the render() call _extract_metrics makes
                mock_region = type("MockRegion", (), {"render": lambda self, crop=True: img})()
                examples[label].append(self._extract_metrics(mock_region))

    # Check we have examples
    for label, exs in examples.items():
        if not exs:
            logger.warning(f"No examples for class '{label}'")
            return

    self.thresholds = {}
    self.metrics_info = {}

    if len(self.labels) != 2:
        logger.warning(
            f"Threshold training supports exactly two labels, got {len(self.labels)}; judge left untrained"
        )
        return

    # Check for class imbalance (every class is non-empty at this point)
    example_counts = {label: len(exs) for label, exs in examples.items()}
    max_count = max(example_counts.values())
    min_count = min(example_counts.values())
    imbalance_ratio = max_count / min_count
    is_imbalanced = imbalance_ratio > 1.5  # Consider imbalanced if more than 1.5x difference

    if is_imbalanced:
        logger.info(f"Class imbalance detected: {example_counts} (ratio {imbalance_ratio:.1f}:1)")
        logger.info("Using balanced accuracy for threshold selection")

    label1, label2 = self.labels

    # Sorted so that ties in selection accuracy resolve the same way every run
    # (set iteration order of strings depends on the hash seed).
    all_metrics = sorted({name for exs in examples.values() for ex in exs for name in ex})

    metric_candidates = []
    for metric in all_metrics:
        vals1 = [ex.get(metric, 0) for ex in examples[label1]]
        vals2 = [ex.get(metric, 0) for ex in examples[label2]]
        threshold, direction, selection_accuracy, youden_j = self._best_binary_split(vals1, vals2)
        metric_candidates.append(
            {
                "metric": metric,
                "youden_j": youden_j,
                "selection_accuracy": selection_accuracy,
                "threshold": threshold,
                "direction": direction,
                "label1": label1,
                "label2": label2,
                "stats": {
                    "mean_" + label1: np.mean(vals1),
                    "mean_" + label2: np.mean(vals2),
                    "std_" + label1: np.std(vals1),
                    "std_" + label2: np.std(vals2),
                },
            }
        )

    # Sort by selection accuracy (stable, so name order breaks ties)
    metric_candidates.sort(key=lambda x: x["selection_accuracy"], reverse=True)

    # Use relaxed cutoff when imbalanced
    keep_cutoff = 0.55 if is_imbalanced else 0.60

    # Keep metrics that pass cutoff, or top 3 if none pass
    kept_metrics = [m for m in metric_candidates if m["selection_accuracy"] > keep_cutoff]
    if not kept_metrics and metric_candidates:
        kept_metrics = metric_candidates[:3]
        logger.warning(
            f"No metrics passed cutoff {keep_cutoff}, keeping top {len(kept_metrics)} metrics"
        )

    # Store selected metrics
    for candidate in kept_metrics:
        metric = candidate["metric"]
        direction = candidate["direction"]
        opposite = "lower" if direction == "higher" else "higher"
        self.thresholds[metric] = {
            "accuracy": candidate["youden_j"],  # Use Youden's J as weight
            "selection_accuracy": candidate["selection_accuracy"],
            "thresholds": {
                candidate["label1"]: (candidate["threshold"], direction),
                candidate["label2"]: (candidate["threshold"], opposite),
            },
        }
        self.metrics_info[metric] = candidate["stats"]

@staticmethod
def _best_binary_split(vals1, vals2):
    """Find the decile threshold and direction that best separate two classes.

    Candidates are the 10th..90th percentiles of the pooled values. "higher"
    means class 1 when value > threshold (class 2 when <=); "lower" means class
    1 when value < threshold (class 2 when >=). The winner maximises balanced
    accuracy; ties keep the earliest candidate.

    Returns:
        ``(threshold, direction, balanced_accuracy, youden_j)`` where Youden's J
        is TPR - FPR with class 1 as positive, clipped to [0, 1].
    """
    first = np.asarray(vals1, dtype=float)
    second = np.asarray(vals2, dtype=float)
    n1, n2 = first.size, second.size

    best_threshold = None
    best_direction = None
    best_accuracy = 0
    best_hits = (0, 0)

    pooled = np.concatenate([first, second])
    for threshold in np.percentile(pooled, [10, 20, 30, 40, 50, 60, 70, 80, 90]):
        for direction in ("higher", "lower"):
            if direction == "higher":
                hits1 = int(np.count_nonzero(first > threshold))
                hits2 = int(np.count_nonzero(second <= threshold))
            else:
                hits1 = int(np.count_nonzero(first < threshold))
                hits2 = int(np.count_nonzero(second >= threshold))

            # Balanced accuracy: fair threshold regardless of class imbalance
            acc1 = hits1 / n1 if n1 > 0 else 0
            acc2 = hits2 / n2 if n2 > 0 else 0
            accuracy = (acc1 + acc2) / 2

            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_threshold = threshold
                best_direction = direction
                best_hits = (hits1, hits2)

    # Youden's J statistic (TPR - FPR) at the chosen split
    tp, tn = best_hits
    tpr = tp / n1 if n1 > 0 else 0
    fpr = (n2 - tn) / n2 if n2 > 0 else 0
    youden_j = max(0.0, min(1.0, tpr - fpr))

    return best_threshold, best_direction, best_accuracy, youden_j
|
1483
|
+
|
1484
|
+
def _load_config(self) -> None:
|
1485
|
+
"""Load configuration from file."""
|
1486
|
+
try:
|
1487
|
+
with open(self.config_path, "r") as f:
|
1488
|
+
config = json.load(f)
|
1489
|
+
|
1490
|
+
self.thresholds = config.get("thresholds", {})
|
1491
|
+
self.metrics_info = config.get("metrics_info", {})
|
1492
|
+
|
1493
|
+
# Verify labels match
|
1494
|
+
if config.get("labels") != self.labels:
|
1495
|
+
logger.warning(
|
1496
|
+
f"Saved labels {config.get('labels')} don't match current {self.labels}"
|
1497
|
+
)
|
1498
|
+
|
1499
|
+
except Exception as e:
|
1500
|
+
logger.warning(f"Failed to load config: {e}")
|
1501
|
+
|
1502
|
+
def _get_training_counts(self) -> Dict[str, int]:
|
1503
|
+
"""Get count of examples per class."""
|
1504
|
+
counts = {}
|
1505
|
+
for label in self.labels:
|
1506
|
+
label_dir = self.root_dir / label
|
1507
|
+
counts[label] = len(list(label_dir.glob("*.png")))
|
1508
|
+
counts["unlabeled"] = len(list((self.root_dir / "unlabeled").glob("*.png")))
|
1509
|
+
return counts
|