natural-pdf 0.2.17__py3-none-any.whl → 0.2.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. natural_pdf/__init__.py +8 -0
  2. natural_pdf/analyzers/checkbox/__init__.py +6 -0
  3. natural_pdf/analyzers/checkbox/base.py +265 -0
  4. natural_pdf/analyzers/checkbox/checkbox_analyzer.py +329 -0
  5. natural_pdf/analyzers/checkbox/checkbox_manager.py +166 -0
  6. natural_pdf/analyzers/checkbox/checkbox_options.py +60 -0
  7. natural_pdf/analyzers/checkbox/mixin.py +95 -0
  8. natural_pdf/analyzers/checkbox/rtdetr.py +201 -0
  9. natural_pdf/collections/mixins.py +14 -5
  10. natural_pdf/core/element_manager.py +5 -1
  11. natural_pdf/core/page.py +103 -9
  12. natural_pdf/core/page_collection.py +41 -1
  13. natural_pdf/core/pdf.py +24 -1
  14. natural_pdf/describe/base.py +20 -0
  15. natural_pdf/elements/base.py +152 -10
  16. natural_pdf/elements/element_collection.py +41 -2
  17. natural_pdf/elements/region.py +115 -2
  18. natural_pdf/judge.py +1509 -0
  19. natural_pdf/selectors/parser.py +42 -1
  20. natural_pdf/utils/spatial.py +42 -39
  21. {natural_pdf-0.2.17.dist-info → natural_pdf-0.2.19.dist-info}/METADATA +1 -1
  22. {natural_pdf-0.2.17.dist-info → natural_pdf-0.2.19.dist-info}/RECORD +42 -18
  23. temp/check_model.py +49 -0
  24. temp/check_pdf_content.py +9 -0
  25. temp/checkbox_checks.py +590 -0
  26. temp/checkbox_simple.py +117 -0
  27. temp/checkbox_ux_ideas.py +400 -0
  28. temp/context_manager_prototype.py +177 -0
  29. temp/convert_to_hf.py +60 -0
  30. temp/demo_text_closest.py +66 -0
  31. temp/inspect_model.py +43 -0
  32. temp/rtdetr_dinov2_test.py +49 -0
  33. temp/test_closest_debug.py +26 -0
  34. temp/test_closest_debug2.py +22 -0
  35. temp/test_context_exploration.py +85 -0
  36. temp/test_durham.py +30 -0
  37. temp/test_empty_string.py +16 -0
  38. temp/test_similarity.py +15 -0
  39. {natural_pdf-0.2.17.dist-info → natural_pdf-0.2.19.dist-info}/WHEEL +0 -0
  40. {natural_pdf-0.2.17.dist-info → natural_pdf-0.2.19.dist-info}/entry_points.txt +0 -0
  41. {natural_pdf-0.2.17.dist-info → natural_pdf-0.2.19.dist-info}/licenses/LICENSE +0 -0
  42. {natural_pdf-0.2.17.dist-info → natural_pdf-0.2.19.dist-info}/top_level.txt +0 -0
natural_pdf/judge.py ADDED
@@ -0,0 +1,1509 @@
+ """
+ Visual Judge for classifying regions based on image content.
+
+ This module provides a simple visual classifier that learns from examples
+ to classify regions (like checkboxes) into categories. It uses basic image
+ metrics rather than neural networks for fast, interpretable results.
+ """
+
+ import hashlib
+ import json
+ import logging
+ import shutil
+ from collections import namedtuple
+ from pathlib import Path
+ from typing import Dict, List, Optional, Tuple, Union
+
+ import numpy as np
+ from PIL import Image
+
+ logger = logging.getLogger(__name__)
+
+ # Return types
+ Decision = namedtuple("Decision", ["label", "score"])
+ PickResult = namedtuple("PickResult", ["region", "index", "label", "score"])
+
+
+ class JudgeError(Exception):
+     """Raised when Judge operations fail."""
+
+     pass
+
+
+ class Judge:
+     """
+     Visual classifier for regions using simple image metrics.
+
+     Exactly two class labels are required, and at least one example of
+     each class must be added before the judge can make decisions.
+
+     Examples:
+         Checkbox detection:
+         ```python
+         judge = Judge("checkboxes", labels=["unchecked", "checked"])
+         judge.add(empty_box, "unchecked")
+         judge.add(marked_box, "checked")
+
+         result = judge.decide(new_box)
+         if result.label == "checked":
+             print("Box is checked!")
+         ```
+
+         Signature detection:
+         ```python
+         judge = Judge("signatures", labels=["unsigned", "signed"])
+         judge.add(blank_area, "unsigned")
+         judge.add(signature_area, "signed")
+
+         result = judge.decide(new_region)
+         print(f"Classification: {result.label} (confidence: {result.score})")
+         ```
+     """
+
+     def __init__(
+         self,
+         name: str,
+         labels: List[str],
+         base_dir: Optional[str] = None,
+         target_prior: Optional[float] = None,
+     ):
+         """
+         Initialize a Judge for visual classification.
+
+         Args:
+             name: Name for this judge (used for the folder name)
+             labels: Class labels (exactly 2; binary classification only)
+             base_dir: Base directory for storage. Defaults to the current directory
+             target_prior: Target prior probability for the FIRST label in the labels list.
+                 - 0.5 (default) = neutral, treats both classes equally
+                 - >0.5 = favors labels[0]
+                 - <0.5 = favors labels[1]
+                 Example: Judge("cb", ["checked", "unchecked"], target_prior=0.6)
+                 favors detecting "checked" checkboxes.
+         """
+         if not labels or len(labels) != 2:
+             raise JudgeError("Judge requires exactly 2 class labels (binary classification only)")
+
+         self.name = name
+         self.labels = labels
+         self.target_prior = target_prior if target_prior is not None else 0.5
+
+         # Set up directory structure
+         self.base_dir = Path(base_dir) if base_dir else Path.cwd()
+         self.root_dir = self.base_dir / name
+         self.root_dir.mkdir(exist_ok=True)
+
+         # Create label directories
+         for label in self.labels:
+             (self.root_dir / label).mkdir(exist_ok=True)
+         (self.root_dir / "unlabeled").mkdir(exist_ok=True)
+         (self.root_dir / "_removed").mkdir(exist_ok=True)
+
+         # Config file
+         self.config_path = self.root_dir / "judge.json"
+
+         # Load existing config or initialize
+         self.thresholds = {}
+         self.metrics_info = {}
+         if self.config_path.exists():
+             self._load_config()
+
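For orientation, a minimal usage sketch of the constructor and the folder layout it maintains (an aside, not part of the diff; behavior follows the code above):

```python
from natural_pdf.judge import Judge

judge = Judge("checkboxes", labels=["unchecked", "checked"])
# Creates ./checkboxes/ with one subfolder per label plus unlabeled/ and
# _removed/; training examples live there as PNGs named by content hash.
# If ./checkboxes/judge.json already exists, its thresholds are loaded.
```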
+     def add(self, region, label: Optional[str] = None) -> None:
+         """
+         Add a region to the judge's dataset.
+
+         Args:
+             region: Region object to add
+             label: Class label. If None, added to unlabeled for later teaching
+
+         Raises:
+             JudgeError: If label is not in allowed labels
+         """
+         if label is not None and label not in self.labels:
+             raise JudgeError(f"Label '{label}' not in allowed labels: {self.labels}")
+
+         # Render region to image
+         try:
+             img = region.render(crop=True)
+             if not isinstance(img, Image.Image):
+                 img = Image.fromarray(img)
+         except Exception as e:
+             raise JudgeError(f"Failed to render region: {e}")
+
+         # Convert to RGB if needed
+         if img.mode != "RGB":
+             img = img.convert("RGB")
+
+         # Generate hash from image content
+         img_array = np.array(img)
+         img_hash = hashlib.md5(img_array.tobytes()).hexdigest()[:12]
+
+         # Determine target directory
+         target_dir = self.root_dir / (label if label else "unlabeled")
+         target_path = target_dir / f"{img_hash}.png"
+
+         # Check if hash already exists anywhere
+         existing_locations = []
+         for check_label in self.labels + ["unlabeled", "_removed"]:
+             check_path = self.root_dir / check_label / f"{img_hash}.png"
+             if check_path.exists():
+                 existing_locations.append(check_label)
+
+         if existing_locations:
+             logger.warning(f"Duplicate image detected (hash: {img_hash})")
+             logger.warning(f"Already exists in: {', '.join(existing_locations)}")
+             print(f"⚠️ Duplicate image - already exists in: {', '.join(existing_locations)}")
+             return
+
+         # Save image
+         img.save(target_path)
+         logger.debug(f"Added image {img_hash} to {label if label else 'unlabeled'}")
+
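The dedup key above is the MD5 of the rendered RGB pixel buffer, truncated to 12 hex characters. A standalone sketch of the same fingerprint (assumes only numpy and Pillow):

```python
import hashlib

import numpy as np
from PIL import Image

def judge_hash(img: Image.Image) -> str:
    # Same recipe as add(): MD5 over raw RGB bytes, first 12 hex chars
    return hashlib.md5(np.array(img.convert("RGB")).tobytes()).hexdigest()[:12]

a = Image.new("RGB", (20, 20), "white")
b = Image.new("RGB", (20, 20), "white")
assert judge_hash(a) == judge_hash(b)  # identical pixels -> same file name, so add() skips the copy
```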
+     def teach(self, labels: Optional[List[str]] = None, review: bool = False) -> None:
+         """
+         Interactive teaching interface using IPython widgets.
+
+         Args:
+             labels: Labels to use for teaching. Defaults to self.labels
+             review: If True, review already labeled images for re-classification
+         """
+         # Check for IPython environment
+         try:
+             import ipywidgets as widgets
+             from IPython.display import clear_output, display
+         except ImportError:
+             raise JudgeError(
+                 "Teaching requires IPython and ipywidgets. Use 'pip install ipywidgets'"
+             )
+
+         labels = labels or self.labels
+
+         # Get images to review
+         if review:
+             # Get all labeled images for review
+             files_to_review = []
+             for label in self.labels:
+                 label_dir = self.root_dir / label
+                 for img_path in sorted(label_dir.glob("*.png")):
+                     files_to_review.append((img_path, label))
+
+             if not files_to_review:
+                 print("No labeled images to review")
+                 return
+
+             # Shuffle for review
+             import random
+
+             random.shuffle(files_to_review)
+             review_files = [f[0] for f in files_to_review]
+             original_labels = {str(f[0]): f[1] for f in files_to_review}
+         else:
+             # Get unlabeled images
+             unlabeled_dir = self.root_dir / "unlabeled"
+             review_files = sorted(unlabeled_dir.glob("*.png"))
+             original_labels = {}
+
+         if not review_files:
+             print("No unlabeled images to teach")
+             return
+
+         # State for teaching
+         self._teaching_state = {
+             "current_index": 0,
+             "labeled_count": 0,
+             "removed_count": 0,
+             "files": review_files,
+             "labels": labels,
+             "review_mode": review,
+             "original_labels": original_labels,
+         }
+
+         # Create widgets
+         image_widget = widgets.Image()
+         status_label = widgets.Label()
+
+         # Create buttons for labeling
+         button_layout = widgets.Layout(width="auto", margin="5px")
+
+         btn_prev = widgets.Button(description="↑ Previous", layout=button_layout)
+         btn_class1 = widgets.Button(
+             description=f"← {labels[0]}", layout=button_layout, button_style="primary"
+         )
+         btn_class2 = widgets.Button(
+             description=f"→ {labels[1]}", layout=button_layout, button_style="success"
+         )
+         btn_skip = widgets.Button(description="↓ Skip", layout=button_layout)
+         btn_remove = widgets.Button(
+             description="✗ Remove", layout=button_layout, button_style="danger"
+         )
+
+         button_box = widgets.HBox([btn_prev, btn_class1, btn_class2, btn_skip, btn_remove])
+
+         # Keyboard shortcuts info
+         info_label = widgets.Label(
+             value="Keys: ↑ prev | ← "
+             + labels[0]
+             + " | → "
+             + labels[1]
+             + " | ↓ skip | Delete remove"
+         )
+
+         def update_display():
+             """Update the displayed image and status."""
+             state = self._teaching_state
+             if 0 <= state["current_index"] < len(state["files"]):
+                 img_path = state["files"][state["current_index"]]
+                 with open(img_path, "rb") as f:
+                     image_widget.value = f.read()
+
+                 # Build status text
+                 status_text = f"Image {state['current_index'] + 1} of {len(state['files'])}"
+                 if state["review_mode"]:
+                     current_label = state["original_labels"].get(str(img_path), "unknown")
+                     status_text += f" (Currently: {current_label})"
+                 status_text += f" | Labeled: {state['labeled_count']}"
+                 if state["removed_count"] > 0:
+                     status_text += f" | Removed: {state['removed_count']}"
+
+                 status_label.value = status_text
+
+                 # Update button states
+                 btn_prev.disabled = state["current_index"] == 0
+             else:
+                 status_label.value = "Teaching complete!"
+                 # Hide the image widget instead of showing broken image
+                 image_widget.layout.display = "none"
+                 # Disable all buttons
+                 btn_prev.disabled = True
+                 btn_class1.disabled = True
+                 btn_class2.disabled = True
+                 btn_skip.disabled = True
+
+                 # Auto-retrain
+                 if state["labeled_count"] > 0 or state["removed_count"] > 0:
+                     clear_output(wait=True)
+                     print("Teaching complete!")
+                     print(f"Labeled: {state['labeled_count']} images")
+                     if state["removed_count"] > 0:
+                         print(f"Removed: {state['removed_count']} images")
+
+                     if state["labeled_count"] > 0:
+                         print("\nRetraining with new examples...")
+                         self._retrain()
+                         print("✓ Training complete! Judge is ready to use.")
+                 else:
+                     print("No changes made.")
+
+         def move_file_to_class(class_index):
+             """Move current file to specified class."""
+             state = self._teaching_state
+             if state["current_index"] >= len(state["files"]):
+                 return
+
+             current_file = state["files"][state["current_index"]]
+             target_dir = self.root_dir / labels[class_index]
+             shutil.move(str(current_file), str(target_dir / current_file.name))
+             state["labeled_count"] += 1
+             state["current_index"] += 1
+             update_display()
+
+         # Button callbacks
+         def on_prev(b):
+             state = self._teaching_state
+             if state["current_index"] > 0:
+                 state["current_index"] -= 1
+                 update_display()
+
+         def on_class1(b):
+             move_file_to_class(0)
+
+         def on_class2(b):
+             move_file_to_class(1)
+
+         def on_skip(b):
+             state = self._teaching_state
+             state["current_index"] += 1
+             update_display()
+
+         def on_remove(b):
+             state = self._teaching_state
+             if state["current_index"] >= len(state["files"]):
+                 return
+
+             current_file = state["files"][state["current_index"]]
+             target_dir = self.root_dir / "_removed"
+             shutil.move(str(current_file), str(target_dir / current_file.name))
+             state["removed_count"] += 1
+             state["current_index"] += 1
+             update_display()
+
+         # Connect buttons
+         btn_prev.on_click(on_prev)
+         btn_class1.on_click(on_class1)
+         btn_class2.on_click(on_class2)
+         btn_skip.on_click(on_skip)
+         btn_remove.on_click(on_remove)
+
+         # Create output widget for keyboard handling
+         output = widgets.Output()
+
+         # Keyboard event handler
+         def on_key(event):
+             """Handle keyboard events."""
+             if event["type"] != "keydown":
+                 return
+
+             key = event["key"]
+
+             if key == "ArrowUp":
+                 on_prev(None)
+             elif key == "ArrowLeft":
+                 on_class1(None)
+             elif key == "ArrowRight":
+                 on_class2(None)
+             elif key == "ArrowDown":
+                 on_skip(None)
+             elif key in ["Delete", "Backspace"]:
+                 on_remove(None)
+
+         # Display everything
+         display(status_label)
+         display(image_widget)
+         display(button_box)
+         display(info_label)
+         display(output)
+
+         # Show first image
+         update_display()
+
+         # Try to set up keyboard handling (may not work in all environments)
+         try:
+             from ipyevents import Event
+
+             event_handler = Event(source=output, watched_events=["keydown"])
+             event_handler.on_dom_event(on_key)
+         except ImportError:
+             # If ipyevents not available, just use buttons
+             print("Note: Install ipyevents for keyboard shortcuts: pip install ipyevents")
+
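In a notebook, the two entry points look like this (usage sketch; requires ipywidgets, as checked above):

```python
judge.teach()             # label whatever is sitting in unlabeled/
judge.teach(review=True)  # revisit already-labeled images, shuffled, and reassign or remove
```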
+     def decide(self, regions: Union["Region", List["Region"]]) -> Union[Decision, List[Decision]]:
+         """
+         Classify one or more regions.
+
+         Args:
+             regions: Single region or list of regions to classify
+
+         Returns:
+             Decision or list of Decisions with label and score
+
+         Raises:
+             JudgeError: If not enough training examples
+         """
+         # Check if we have examples
+         for label in self.labels:
+             label_dir = self.root_dir / label
+             if not any(label_dir.glob("*.png")):
+                 raise JudgeError(f"Need at least one example of class '{label}' before deciding")
+
+         # Ensure thresholds are current
+         if not self.thresholds:
+             self._retrain()
+
+         # Handle single region
+         single_input = not isinstance(regions, list)
+         if single_input:
+             regions = [regions]
+
+         results = []
+         for region in regions:
+             # Extract metrics
+             metrics = self._extract_metrics(region)
+
+             # Apply thresholds with soft voting
+             votes = {label: 0.0 for label in self.labels}
+             total_weight = 0.0
+
+             for metric_name, value in metrics.items():
+                 if metric_name in self.thresholds:
+                     metric_info = self.thresholds[metric_name]
+                     weight = metric_info["accuracy"]  # This is now Youden's J
+
+                     # For binary classification
+                     label1, label2 = self.labels
+                     threshold1, direction1 = metric_info["thresholds"][label1]
+
+                     # Get standard deviations for soft voting
+                     stats = self.metrics_info.get(metric_name, {})
+                     s1 = stats.get(f"std_{label1}", 0.0)
+                     s2 = stats.get(f"std_{label2}", 0.0)
+                     scale1 = s1 if s1 > 1e-6 else 1.0
+                     scale2 = s2 if s2 > 1e-6 else 1.0
+
+                     # Calculate signed margin (positive favors label1, negative favors label2)
+                     if direction1 == "higher":
+                         margin = (value - threshold1) / (scale1 if value >= threshold1 else scale2)
+                     else:
+                         margin = (threshold1 - value) / (scale1 if value <= threshold1 else scale2)
+
+                     # Clip margin to avoid single metric dominating
+                     margin = np.clip(margin, -6, 6)
+
+                     # Soft votes using sigmoid
+                     p1 = 1.0 / (1.0 + np.exp(-margin))
+                     p2 = 1.0 - p1
+
+                     votes[label1] += weight * p1
+                     votes[label2] += weight * p2
+                     total_weight += weight
+
+             # Normalize votes
+             if total_weight > 0:
+                 for label in votes:
+                     votes[label] /= total_weight
+             else:
+                 # Fallback: uniform votes so prior still works
+                 for label in votes:
+                     votes[label] = 0.5
+                 total_weight = 1.0
+
+             # Apply prior bias correction
+             def _logit(p, eps=1e-6):
+                 p = max(eps, min(1 - eps, p))
+                 return np.log(p / (1 - p))
+
+             def _sigmoid(x):
+                 if x >= 0:
+                     z = np.exp(-x)
+                     return 1.0 / (1.0 + z)
+                 else:
+                     z = np.exp(x)
+                     return z / (1.0 + z)
+
+             # Estimate priors from training counts
+             counts = self._get_training_counts()
+             label1, label2 = self.labels
+             n1 = counts.get(label1, 0)
+             n2 = counts.get(label2, 0)
+             total = max(1, n1 + n2)
+
+             if n1 > 0 and n2 > 0:  # Only apply bias if we have examples of both classes
+                 emp_prior1 = n1 / total
+                 emp_prior2 = n2 / total
+
+                 # Target prior (0.5/0.5 neutralizes imbalance)
+                 target_prior1 = self.target_prior
+                 target_prior2 = 1.0 - self.target_prior
+
+                 # Calculate bias
+                 bias1 = _logit(target_prior1) - _logit(emp_prior1)
+                 bias2 = _logit(target_prior2) - _logit(emp_prior2)
+
+                 # Apply bias in logit space
+                 v1 = _sigmoid(_logit(votes[label1]) + bias1)
+                 v2 = _sigmoid(_logit(votes[label2]) + bias2)
+
+                 # Renormalize
+                 s = v1 + v2
+                 votes[label1] = v1 / s
+                 votes[label2] = v2 / s
+
+             # Find best label
+             best_label = max(votes.items(), key=lambda x: x[1])
+             results.append(Decision(label=best_label[0], score=best_label[1]))
+
+         return results[0] if single_input else results
+
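Condensing the loop above into one worked pass: each kept metric casts a weighted sigmoid vote on its signed margin, and class imbalance is corrected once in logit space. A self-contained sketch with invented numbers:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def logit(p, eps=1e-6):
    p = np.clip(p, eps, 1 - eps)
    return np.log(p / (1 - p))

# One metric: value 120 vs. threshold 90 ("higher" favors "checked"),
# class std 15, weight = Youden's J = 0.8
margin = np.clip((120 - 90) / 15, -6, 6)    # 2.0
p1 = sigmoid(margin)                        # ~0.88 soft vote for "checked"
votes = {"checked": 0.8 * p1, "unchecked": 0.8 * (1 - p1)}
total_weight = 0.8
votes = {k: v / total_weight for k, v in votes.items()}

# Prior correction: 30 "checked" vs 10 "unchecked" examples, target_prior 0.5
v1 = sigmoid(logit(votes["checked"]) + logit(0.5) - logit(30 / 40))
v2 = sigmoid(logit(votes["unchecked"]) + logit(0.5) - logit(10 / 40))
print({"checked": v1 / (v1 + v2), "unchecked": v2 / (v1 + v2)})  # ~0.71 / ~0.29
```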
+     def pick(
+         self, target_label: str, regions: List["Region"], labels: Optional[List[str]] = None
+     ) -> PickResult:
+         """
+         Pick which region best matches the target label.
+
+         Args:
+             target_label: The class label to look for
+             regions: List of regions to choose from
+             labels: Optional human-friendly labels for each region
+
+         Returns:
+             PickResult with winning region, index, label (if provided), and score
+
+         Raises:
+             JudgeError: If target_label not in allowed labels
+         """
+         if target_label not in self.labels:
+             raise JudgeError(f"Target label '{target_label}' not in allowed labels: {self.labels}")
+
+         # Classify all regions
+         decisions = self.decide(regions)
+
+         # Find best match for target label
+         best_index = -1
+         best_score = -1.0
+
+         for i, decision in enumerate(decisions):
+             if decision.label == target_label and decision.score > best_score:
+                 best_score = decision.score
+                 best_index = i
+
+         if best_index == -1:
+             # No region matched the target label
+             raise JudgeError(f"No region classified as '{target_label}'")
+
+         # Build result
+         region = regions[best_index]
+         label = labels[best_index] if labels and best_index < len(labels) else None
+
+         return PickResult(region=region, index=best_index, label=label, score=best_score)
+
+     def count(self, target_label: str, regions: List["Region"]) -> int:
+         """
+         Count how many regions match the target label.
+
+         Args:
+             target_label: The class label to count
+             regions: List of regions to check
+
+         Returns:
+             Number of regions classified as target_label
+         """
+         decisions = self.decide(regions)
+         return sum(1 for d in decisions if d.label == target_label)
+
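A usage sketch for the two batch helpers (`option_boxes` is a placeholder list of regions):

```python
winner = judge.pick("checked", option_boxes, labels=["Yes", "No", "N/A"])
print(winner.index, winner.label, round(winner.score, 2))  # e.g. 0 'Yes' 0.91

n_checked = judge.count("checked", option_boxes)  # how many came back "checked"
```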
+     def info(self) -> None:
+         """
+         Show configuration and training information for this Judge.
+         """
+         print(f"Judge: {self.name}")
+         print(f"Labels: {self.labels}")
+         if self.target_prior != 0.5:
+             print(
+                 f"Target prior: {self.target_prior:.2f} (favors '{self.labels[0]}')"
+                 if self.target_prior > 0.5
+                 else f"Target prior: {self.target_prior:.2f} (favors '{self.labels[1]}')"
+             )
+
+         # Get training counts
+         counts = self._get_training_counts()
+         print("\nTraining examples:")
+         for label in self.labels:
+             count = counts.get(label, 0)
+             print(f" {label}: {count}")
+
+         if counts.get("unlabeled", 0) > 0:
+             print(f" unlabeled: {counts['unlabeled']}")
+
+         # Show actual imbalance
+         labeled_counts = [counts.get(label, 0) for label in self.labels]
+         if all(c > 0 for c in labeled_counts):
+             max_count = max(labeled_counts)
+             min_count = min(labeled_counts)
+             if max_count != min_count:
+                 # Find which is which
+                 for i, label in enumerate(self.labels):
+                     if counts.get(label, 0) == max_count:
+                         majority_label = label
+                     if counts.get(label, 0) == min_count:
+                         minority_label = label
+
+                 ratio = max_count / min_count
+                 print(
+                     f"\nClass imbalance: {majority_label}:{minority_label} = {max_count}:{min_count} ({ratio:.1f}:1)"
+                 )
+
+                 print(" Using Youden's J weights with soft voting and prior correction")
+
+     def inspect(self, preview: bool = True) -> None:
+         """
+         Inspect all stored examples, showing their true labels and predicted labels/scores.
+         Useful for debugging classification issues.
+
+         Args:
+             preview: If True (default), display images inline in HTML tables (requires IPython/Jupyter).
+                 If False, use text-only output.
+         """
+         if not self.thresholds:
+             print("No trained model yet. Add examples and the model will auto-train.")
+             return
+
+         if not preview:
+             # Show basic info first
+             self.info()
+             print("-" * 80)
+
+             print("\nThresholds learned:")
+             for metric, info in self.thresholds.items():
+                 weight = info["accuracy"]  # This is now Youden's J
+                 selection_acc = info.get(
+                     "selection_accuracy", info["accuracy"]
+                 )  # Fallback for old models
+                 print(f" {metric}: weight={weight:.3f} (selection_accuracy={selection_acc:.3f})")
+                 for label, (threshold, direction) in info["thresholds"].items():
+                     print(f" {label}: {direction} than {threshold:.3f}")
+
+                 # Show metric distribution info if available
+                 if metric in self.metrics_info:
+                     metric_stats = self.metrics_info[metric]
+                     for label in self.labels:
+                         mean_key = f"mean_{label}"
+                         std_key = f"std_{label}"
+                         if mean_key in metric_stats:
+                             print(
+                                 f" {label} distribution: mean={metric_stats[mean_key]:.3f}, std={metric_stats[std_key]:.3f}"
+                             )
+
+         if preview:
+             # HTML preview mode
+             try:
+                 import base64
+                 import io
+
+                 from IPython.display import HTML, display
+             except ImportError:
+                 print("Preview mode requires IPython/Jupyter. Falling back to text mode.")
+                 preview = False
+
+         if preview:
+             # Build HTML tables for everything
+             html_parts = []
+             html_parts.append("<style>")
+             html_parts.append("table { border-collapse: collapse; margin: 20px 0; }")
+             html_parts.append("th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }")
+             html_parts.append("th { background-color: #f2f2f2; font-weight: bold; }")
+             html_parts.append("img { max-width: 60px; max-height: 60px; }")
+             html_parts.append(".correct { color: green; }")
+             html_parts.append(".incorrect { color: red; }")
+             html_parts.append(".metrics { font-size: 0.9em; color: #666; }")
+             html_parts.append("h3 { margin-top: 30px; }")
+             html_parts.append(".imbalance-warning { background-color: #fff3cd; color: #856404; }")
+             html_parts.append("</style>")
+
+             # Configuration table
+             html_parts.append("<h3>Judge Configuration</h3>")
+             html_parts.append("<table>")
+             html_parts.append("<tr><th>Property</th><th>Value</th></tr>")
+             html_parts.append(f"<tr><td>Name</td><td>{self.name}</td></tr>")
+             html_parts.append(f"<tr><td>Labels</td><td>{', '.join(self.labels)}</td></tr>")
+             html_parts.append(f"<tr><td>Target Prior</td><td>{self.target_prior:.2f}")
+             if self.target_prior != 0.5:
+                 html_parts.append(
+                     f" (favors '{self.labels[0] if self.target_prior > 0.5 else self.labels[1]}')"
+                 )
+             html_parts.append("</td></tr>")
+             html_parts.append("</table>")
+
+             # Training counts table
+             counts = self._get_training_counts()
+             html_parts.append("<h3>Training Examples</h3>")
+             html_parts.append("<table>")
+             html_parts.append("<tr><th>Class</th><th>Count</th></tr>")
+
+             # Check for imbalance
+             labeled_counts = [counts.get(label, 0) for label in self.labels]
+             is_imbalanced = False
+             if all(c > 0 for c in labeled_counts):
+                 max_count = max(labeled_counts)
+                 min_count = min(labeled_counts)
+                 if max_count != min_count:
+                     ratio = max_count / min_count
+                     is_imbalanced = ratio > 1.5
+
+             for label in self.labels:
+                 count = counts.get(label, 0)
+                 row_class = ""
+                 if is_imbalanced:
+                     if count == max(labeled_counts):
+                         row_class = ' class="imbalance-warning"'
+                 html_parts.append(f"<tr{row_class}><td>{label}</td><td>{count}</td></tr>")
+
+             if counts.get("unlabeled", 0) > 0:
+                 html_parts.append(f"<tr><td>unlabeled</td><td>{counts['unlabeled']}</td></tr>")
+
+             html_parts.append("</table>")
+
+             if is_imbalanced:
+                 html_parts.append(
+                     f"<p><em>Class imbalance detected ({ratio:.1f}:1). Using Youden's J weights with prior correction.</em></p>"
+                 )
+
+             # Thresholds table
+             html_parts.append("<h3>Learned Thresholds</h3>")
+             html_parts.append("<table>")
+             html_parts.append(
+                 "<tr><th>Metric</th><th>Weight (Youden's J)</th><th>Selection Accuracy</th><th>Threshold Details</th></tr>"
+             )
+
+             for metric, info in self.thresholds.items():
+                 weight = info["accuracy"]  # This is Youden's J
+                 selection_acc = info.get("selection_accuracy", weight)
+
+                 # Build threshold details
+                 details = []
+                 for label, (threshold, direction) in info["thresholds"].items():
+                     details.append(f"<br>{label}: {direction} than {threshold:.3f}")
+
+                 # Add distribution info if available
+                 if metric in self.metrics_info:
+                     metric_stats = self.metrics_info[metric]
+                     details.append("<br><em>Distributions:</em>")
+                     for label in self.labels:
+                         mean_key = f"mean_{label}"
+                         std_key = f"std_{label}"
+                         if mean_key in metric_stats:
+                             details.append(
+                                 f"<br>&nbsp;&nbsp;{label}: μ={metric_stats[mean_key]:.1f}, σ={metric_stats[std_key]:.1f}"
+                             )
+
+                 html_parts.append("<tr>")
+                 html_parts.append(f"<td>{metric}</td>")
+                 html_parts.append(f"<td>{weight:.3f}</td>")
+                 html_parts.append(f"<td>{selection_acc:.3f}</td>")
+                 html_parts.append(f"<td>{''.join(details)}</td>")
+                 html_parts.append("</tr>")
+
+             html_parts.append("</table>")
+
+             all_correct = 0
+             all_total = 0
+
+             # First show labeled examples
+             for true_label in self.labels:
+                 label_dir = self.root_dir / true_label
+                 examples = list(label_dir.glob("*.png"))
+
+                 if not examples:
+                     continue
+
+                 html_parts.append(
+                     f"<h3>Predictions: {true_label.upper()} ({len(examples)} total)</h3>"
+                 )
+                 html_parts.append("<table>")
+                 html_parts.append(
+                     "<tr><th>Image</th><th>Status</th><th>Predicted</th><th>Score</th><th>Key Metrics</th></tr>"
+                 )
+
+                 correct = 0
+
+                 for img_path in sorted(examples)[:20]:  # Show max 20 per class in preview
+                     # Load image
+                     img = Image.open(img_path)
+                     mock_region = type("MockRegion", (), {"render": lambda self, crop=True: img})()
+
+                     # Get prediction
+                     decision = self.decide(mock_region)
+                     is_correct = decision.label == true_label
+                     if is_correct:
+                         correct += 1
+
+                     # Extract metrics
+                     metrics = self._extract_metrics(mock_region)
+
+                     # Convert image to base64
+                     buffered = io.BytesIO()
+                     img.save(buffered, format="PNG")
+                     img_str = base64.b64encode(buffered.getvalue()).decode()
+
+                     # Build row
+                     status_class = "correct" if is_correct else "incorrect"
+                     status_symbol = "✓" if is_correct else "✗"
+
+                     # Format key metrics
+                     metric_strs = []
+                     for metric, value in sorted(metrics.items()):
+                         if metric in self.thresholds:
+                             metric_strs.append(f"{metric}={value:.1f}")
+                     metrics_html = "<br>".join(metric_strs[:3])
+
+                     html_parts.append("<tr>")
+                     html_parts.append(f'<td><img src="data:image/png;base64,{img_str}" /></td>')
+                     html_parts.append(f'<td class="{status_class}">{status_symbol}</td>')
+                     html_parts.append(f"<td>{decision.label}</td>")
+                     html_parts.append(f"<td>{decision.score:.3f}</td>")
+                     html_parts.append(f'<td class="metrics">{metrics_html}</td>')
+                     html_parts.append("</tr>")
+
+                 html_parts.append("</table>")
+
+                 accuracy = correct / len(examples) if examples else 0
+                 all_correct += correct
+                 all_total += len(examples)
+
+                 if len(examples) > 20:
+                     html_parts.append(f"<p><em>... and {len(examples) - 20} more</em></p>")
+                 html_parts.append(
+                     f"<p>Accuracy for {true_label}: <strong>{accuracy:.1%}</strong> ({correct}/{len(examples)})</p>"
+                 )
+
+             if all_total > 0:
+                 overall_accuracy = all_correct / all_total
+                 html_parts.append(
+                     f"<h3>Overall accuracy: {overall_accuracy:.1%} ({all_correct}/{all_total})</h3>"
+                 )
+
+             # Now show unlabeled examples with predictions
+             unlabeled_dir = self.root_dir / "unlabeled"
+             unlabeled_examples = list(unlabeled_dir.glob("*.png"))
+
+             if unlabeled_examples:
+                 html_parts.append(
+                     f"<h3>Predictions: UNLABELED ({len(unlabeled_examples)} total)</h3>"
+                 )
+                 html_parts.append("<table>")
+                 html_parts.append(
+                     "<tr><th>Image</th><th>Predicted</th><th>Score</th><th>Key Metrics</th></tr>"
+                 )
+
+                 for img_path in sorted(unlabeled_examples)[:20]:  # Show max 20
+                     # Load image
+                     img = Image.open(img_path)
+                     mock_region = type("MockRegion", (), {"render": lambda self, crop=True: img})()
+
+                     # Get prediction
+                     decision = self.decide(mock_region)
+
+                     # Extract metrics
+                     metrics = self._extract_metrics(mock_region)
+
+                     # Convert image to base64
+                     buffered = io.BytesIO()
+                     img.save(buffered, format="PNG")
+                     img_str = base64.b64encode(buffered.getvalue()).decode()
+
+                     # Format key metrics
+                     metric_strs = []
+                     for metric, value in sorted(metrics.items()):
+                         if metric in self.thresholds:
+                             metric_strs.append(f"{metric}={value:.1f}")
+                     metrics_html = "<br>".join(metric_strs[:3])
+
+                     html_parts.append("<tr>")
+                     html_parts.append(f'<td><img src="data:image/png;base64,{img_str}" /></td>')
+                     html_parts.append(f"<td>{decision.label}</td>")
+                     html_parts.append(f"<td>{decision.score:.3f}</td>")
+                     html_parts.append(f'<td class="metrics">{metrics_html}</td>')
+                     html_parts.append("</tr>")
+
+                 html_parts.append("</table>")
+
+                 if len(unlabeled_examples) > 20:
+                     html_parts.append(
+                         f"<p><em>... and {len(unlabeled_examples) - 20} more</em></p>"
+                     )
+
+             # Display HTML
+             display(HTML("".join(html_parts)))
+
+         else:
+             # Text mode (original)
+             print("\nPredictions on training data:")
+             print("-" * 80)
+
+             # Test each labeled example
+             all_correct = 0
+             all_total = 0
+
+             for true_label in self.labels:
+                 label_dir = self.root_dir / true_label
+                 examples = list(label_dir.glob("*.png"))
+
+                 if not examples:
+                     continue
+
+                 print(f"\n{true_label.upper()} examples ({len(examples)} total):")
+                 correct = 0
+
+                 for img_path in sorted(examples)[:10]:  # Show max 10 per class
+                     # Load image and create mock region
+                     img = Image.open(img_path)
+                     mock_region = type("MockRegion", (), {"render": lambda self, crop=True: img})()
+
+                     # Get prediction
+                     decision = self.decide(mock_region)
+                     is_correct = decision.label == true_label
+                     if is_correct:
+                         correct += 1
+
+                     # Extract metrics for this example
+                     metrics = self._extract_metrics(mock_region)
+
+                     # Show result
+                     status = "✓" if is_correct else "✗"
+                     print(
+                         f" {status} {img_path.name}: predicted={decision.label} (score={decision.score:.3f})"
+                     )
+
+                     # Show key metric values
+                     metric_strs = []
+                     for metric, value in sorted(metrics.items()):
+                         if metric in self.thresholds:
+                             metric_strs.append(f"{metric}={value:.2f}")
+                     if metric_strs:
+                         print(f" Metrics: {', '.join(metric_strs[:3])}")
+
+                 accuracy = correct / len(examples) if examples else 0
+                 all_correct += correct
+                 all_total += len(examples)
+
+                 if len(examples) > 10:
+                     print(f" ... and {len(examples) - 10} more")
+                 print(f" Accuracy for {true_label}: {accuracy:.1%} ({correct}/{len(examples)})")
+
+             if all_total > 0:
+                 overall_accuracy = all_correct / all_total
+                 print(f"\nOverall accuracy: {overall_accuracy:.1%} ({all_correct}/{all_total})")
+
+             # Show unlabeled examples with predictions
+             unlabeled_dir = self.root_dir / "unlabeled"
+             unlabeled_examples = list(unlabeled_dir.glob("*.png"))
+
+             if unlabeled_examples:
+                 print(f"\nUNLABELED examples ({len(unlabeled_examples)} total) - predictions:")
+
+                 for img_path in sorted(unlabeled_examples)[:10]:  # Show max 10
+                     # Load image and create mock region
+                     img = Image.open(img_path)
+                     mock_region = type("MockRegion", (), {"render": lambda self, crop=True: img})()
+
+                     # Get prediction
+                     decision = self.decide(mock_region)
+
+                     # Extract metrics
+                     metrics = self._extract_metrics(mock_region)
+
+                     print(
+                         f" {img_path.name}: predicted={decision.label} (score={decision.score:.3f})"
+                     )
+
+                     # Show key metric values
+                     metric_strs = []
+                     for metric, value in sorted(metrics.items()):
+                         if metric in self.thresholds:
+                             metric_strs.append(f"{metric}={value:.2f}")
+                     if metric_strs:
+                         print(f" Metrics: {', '.join(metric_strs[:3])}")
+
+                 if len(unlabeled_examples) > 10:
+                     print(f" ... and {len(unlabeled_examples) - 10} more")
+
+     def lookup(self, region) -> Optional[Tuple[str, Image.Image]]:
+         """
+         Look up a region and return its hash and image if found in training data.
+
+         Args:
+             region: Region to look up
+
+         Returns:
+             Tuple of (hash, image) if found, None if not found
+         """
+         try:
+             # Generate hash for the region
+             img = region.render(crop=True)
+             if not isinstance(img, Image.Image):
+                 img = Image.fromarray(img)
+             if img.mode != "RGB":
+                 img = img.convert("RGB")
+             img_array = np.array(img)
+             img_hash = hashlib.md5(img_array.tobytes()).hexdigest()[:12]
+
+             # Look for the image in the label directories plus unlabeled/_removed
+             for subdir in self.labels + ["unlabeled", "_removed"]:
+                 img_path = self.root_dir / subdir / f"{img_hash}.png"
+                 if img_path.exists():
+                     stored_img = Image.open(img_path)
+                     logger.debug(f"Found region in '{subdir}' with hash {img_hash}")
+                     return (img_hash, stored_img)
+
+             logger.debug(f"Region not found in training data (hash: {img_hash})")
+             return None
+
+         except Exception as e:
+             logger.error(f"Failed to lookup region: {e}")
+             return None
+
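Round-trip sketch (`region` is a placeholder for any region object):

```python
hit = judge.lookup(region)
if hit is not None:
    img_hash, stored = hit  # 12-char hash plus the stored training crop
```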
+     def show(self, max_per_class: int = 10, size: Tuple[int, int] = (100, 100)) -> None:
+         """
+         Display a grid showing examples from each category.
+
+         Args:
+             max_per_class: Maximum number of examples to show per class
+             size: Size of each image in pixels (width, height)
+         """
+         try:
+             import io
+
+             import ipywidgets as widgets
+             from IPython.display import display
+             from PIL import Image as PILImage
+         except ImportError:
+             print("Show requires IPython and ipywidgets")
+             return
+
+         # Collect images from each category
+         categories = {}
+         total_counts = {}
+         for label in self.labels:
+             label_dir = self.root_dir / label
+             all_images = list(label_dir.glob("*.png"))
+             total_counts[label] = len(all_images)
+             images = sorted(all_images)[:max_per_class]
+             if images:
+                 categories[label] = images
+
+         # Add unlabeled if any
+         unlabeled_dir = self.root_dir / "unlabeled"
+         all_unlabeled = list(unlabeled_dir.glob("*.png"))
+         total_counts["unlabeled"] = len(all_unlabeled)
+         unlabeled = sorted(all_unlabeled)[:max_per_class]
+         if unlabeled:
+             categories["unlabeled"] = unlabeled
+
+         if not categories:
+             print("No images to show")
+             return
+
+         # Create grid layout
+         rows = []
+
+         # Check for class imbalance
+         labeled_counts = {k: v for k, v in total_counts.items() if k != "unlabeled"}
+         if labeled_counts and len(labeled_counts) >= 2:
+             max_count = max(labeled_counts.values())
+             min_count = min(labeled_counts.values())
+             if min_count > 0 and max_count / min_count > 3:
+                 warning = widgets.HTML(
+                     f'<div style="background: #fff3cd; padding: 10px; margin: 10px 0; border: 1px solid #ffeeba; border-radius: 4px;">'
+                     f"<strong>⚠️ Class imbalance detected:</strong> {labeled_counts}<br>"
+                     f"Consider adding more examples of the minority class for better accuracy."
+                     f"</div>"
+                 )
+                 rows.append(warning)
+
+         for category, image_paths in categories.items():
+             # Category header showing total count
+             shown = len(image_paths)
+             total = total_counts[category]
+             header_text = f"<h3>{category}"
+             if shown < total:
+                 header_text += f" ({shown} of {total} shown)"
+             else:
+                 header_text += f" ({total} total)"
+             header_text += "</h3>"
+             header = widgets.HTML(header_text)
+
+             # Image row
+             image_widgets = []
+             for img_path in image_paths:
+                 # Load and resize image
+                 img = PILImage.open(img_path)
+                 img.thumbnail(size, PILImage.Resampling.LANCZOS)
+
+                 # Convert to bytes for display
+                 img_bytes = io.BytesIO()
+                 img.save(img_bytes, format="PNG")
+                 img_bytes.seek(0)
+
+                 # Create image widget
+                 img_widget = widgets.Image(value=img_bytes.read(), width=size[0], height=size[1])
+                 image_widgets.append(img_widget)
+
+             # Create horizontal box for this category
+             category_box = widgets.VBox([header, widgets.HBox(image_widgets)])
+             rows.append(category_box)
+
+         # Display all categories
+         display(widgets.VBox(rows))
+
+     def forget(self, region: Optional["Region"] = None, delete: bool = False) -> None:
+         """
+         Clear training data, delete all files, or move a specific region to unlabeled.
+
+         Args:
+             region: If provided, move this specific region to unlabeled
+             delete: If True, permanently delete all files
+         """
+         # Handle specific region case
+         if region is not None:
+             # Get hash of the region
+             try:
+                 img = region.render(crop=True)
+                 if not isinstance(img, Image.Image):
+                     img = Image.fromarray(img)
+                 if img.mode != "RGB":
+                     img = img.convert("RGB")
+                 img_array = np.array(img)
+                 img_hash = hashlib.md5(img_array.tobytes()).hexdigest()[:12]
+             except Exception as e:
+                 logger.error(f"Failed to hash region: {e}")
+                 return
+
+             # Find and move the image
+             moved = False
+             for label in self.labels + ["_removed"]:
+                 source_path = self.root_dir / label / f"{img_hash}.png"
+                 if source_path.exists():
+                     target_path = self.root_dir / "unlabeled" / f"{img_hash}.png"
+                     shutil.move(str(source_path), str(target_path))
+                     print(f"Moved region from '{label}' to 'unlabeled'")
+                     moved = True
+                     break
+
+             if not moved:
+                 print("Region not found in training data")
+             return
+
+         # Handle delete or clear training
+         if delete:
+             # Delete entire directory
+             if self.root_dir.exists():
+                 shutil.rmtree(self.root_dir)
+                 print(f"Deleted all data for judge '{self.name}'")
+             else:
+                 print(f"No data found for judge '{self.name}'")
+
+             # Reset internal state
+             self.thresholds = {}
+             self.metrics_info = {}
+
+             # Recreate directory structure
+             self.root_dir.mkdir(exist_ok=True)
+             for label in self.labels:
+                 (self.root_dir / label).mkdir(exist_ok=True)
+             (self.root_dir / "unlabeled").mkdir(exist_ok=True)
+             (self.root_dir / "_removed").mkdir(exist_ok=True)
+
+         else:
+             # Just clear training (move everything to unlabeled)
+             moved_count = 0
+
+             # Move all labeled images back to unlabeled
+             unlabeled_dir = self.root_dir / "unlabeled"
+             for label in self.labels:
+                 label_dir = self.root_dir / label
+                 if label_dir.exists():
+                     for img_path in label_dir.glob("*.png"):
+                         shutil.move(str(img_path), str(unlabeled_dir / img_path.name))
+                         moved_count += 1
+
+             # Clear thresholds
+             self.thresholds = {}
+             self.metrics_info = {}
+
+             # Remove saved config
+             if self.config_path.exists():
+                 self.config_path.unlink()
+
+             print(f"Moved {moved_count} labeled images back to unlabeled.")
+             print("Training data cleared. Judge is now untrained.")
+
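The three modes side by side (usage sketch; `some_region` is a placeholder):

```python
judge.forget(region=some_region)  # demote one example back to unlabeled/
judge.forget()                    # move all labeled images to unlabeled/, drop thresholds
judge.forget(delete=True)         # wipe the judge's folder, then recreate it empty
```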
+     def save(self, path: Optional[str] = None) -> None:
+         """
+         Save the judge configuration (auto-retrains first).
+
+         Args:
+             path: Optional path to save to. Defaults to judge.json in root directory
+         """
+         # Retrain with current examples
+         self._retrain()
+
+         # Save config
+         save_path = Path(path) if path else self.config_path
+
+         config = {
+             "name": self.name,
+             "labels": self.labels,
+             "target_prior": self.target_prior,
+             "thresholds": self.thresholds,
+             "metrics_info": self.metrics_info,
+             "training_counts": self._get_training_counts(),
+         }
+
+         with open(save_path, "w") as f:
+             json.dump(config, f, indent=2)
+
+         logger.info(f"Saved judge to {save_path}")
+
+     @classmethod
+     def load(cls, path: str) -> "Judge":
+         """
+         Load a judge from a saved configuration.
+
+         Args:
+             path: Path to the saved judge.json file or the judge directory
+
+         Returns:
+             Loaded Judge instance
+         """
+         path = Path(path)
+
+         # If path is a directory, look for judge.json inside
+         if path.is_dir():
+             config_path = path / "judge.json"
+             base_dir = path.parent
+             name = path.name
+         else:
+             config_path = path
+             base_dir = path.parent.parent if path.parent.name != "." else path.parent
+             # Try to infer name from path
+             name = None
+
+         with open(config_path, "r") as f:
+             config = json.load(f)
+
+         # Use saved name if we couldn't infer it
+         if name is None:
+             name = config["name"]
+
+         # Create judge with saved config
+         judge = cls(
+             name,
+             labels=config["labels"],
+             base_dir=base_dir,
+             target_prior=config.get("target_prior", 0.5),
+         )  # Default to 0.5 for old configs
+         judge.thresholds = config["thresholds"]
+         judge.metrics_info = config.get("metrics_info", {})
+
+         return judge
+
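Persistence round trip (usage sketch, relying only on the two methods above):

```python
judge.save()                         # retrains, then writes checkboxes/judge.json
restored = Judge.load("checkboxes")  # accepts the folder or the judge.json path
assert restored.labels == judge.labels
```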
+     # Private methods
+
+     def _extract_metrics(self, region) -> Dict[str, float]:
+         """Extract image metrics from a region."""
+         try:
+             img = region.render(crop=True)
+             if not isinstance(img, Image.Image):
+                 img = Image.fromarray(img)
+
+             # Convert to grayscale for analysis
+             gray = np.array(img.convert("L"))
+
+             metrics = {}
+
+             # 1. Center darkness
+             h, w = gray.shape
+             cy, cx = h // 2, w // 2
+             center_size = min(5, h // 4, w // 4)  # Adaptive center size
+             center = gray[
+                 max(0, cy - center_size) : min(h, cy + center_size + 1),
+                 max(0, cx - center_size) : min(w, cx + center_size + 1),
+             ]
+             metrics["center_darkness"] = 255 - np.mean(center)
+
+             # 2. Overall darkness (ink density)
+             metrics["ink_density"] = 255 - np.mean(gray)
+
+             # 3. Dark pixel ratio
+             metrics["dark_pixel_ratio"] = np.sum(gray < 200) / gray.size
+
+             # 4. Standard deviation (complexity)
+             metrics["std_dev"] = np.std(gray)
+
+             # 5. Edge vs center ratio
+             edge_size = max(2, min(h // 10, w // 10))
+             edge_mask = np.zeros_like(gray, dtype=bool)
+             edge_mask[:edge_size, :] = True
+             edge_mask[-edge_size:, :] = True
+             edge_mask[:, :edge_size] = True
+             edge_mask[:, -edge_size:] = True
+
+             edge_mean = np.mean(gray[edge_mask]) if np.any(edge_mask) else 255
+             center_mean = np.mean(center)
+             metrics["edge_center_ratio"] = edge_mean / (center_mean + 1)
+
+             # 6. Diagonal density (for X patterns)
+             if h > 10 and w > 10:
+                 diag_mask = np.zeros_like(gray, dtype=bool)
+                 for i in range(min(h, w)):
+                     if i < h and i < w:
+                         diag_mask[i, i] = True
+                         diag_mask[i, w - 1 - i] = True
+                 metrics["diagonal_density"] = 255 - np.mean(gray[diag_mask])
+             else:
+                 metrics["diagonal_density"] = metrics["ink_density"]
+
+             return metrics
+
+         except Exception as e:
+             raise JudgeError(f"Failed to extract metrics: {e}")
+
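To make the metrics concrete, here is what three of them return for a synthetic "checked" crop (standalone sketch; an all-white crop scores near zero on each, which is the separation _retrain() thresholds on):

```python
import numpy as np
from PIL import Image, ImageDraw

img = Image.new("L", (40, 40), 255)          # white box standing in for a checkbox
draw = ImageDraw.Draw(img)
draw.line((4, 4, 35, 35), fill=0, width=3)   # draw an X mark
draw.line((35, 4, 4, 35), fill=0, width=3)

gray = np.array(img)
print(255 - gray.mean())                # ink_density
print((gray < 200).sum() / gray.size)   # dark_pixel_ratio
print(gray.std())                       # std_dev
```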
+     def _retrain(self) -> None:
+         """Retrain thresholds from current examples."""
+         # Collect all examples
+         examples = {label: [] for label in self.labels}
+
+         for label in self.labels:
+             label_dir = self.root_dir / label
+             for img_path in label_dir.glob("*.png"):
+                 img = Image.open(img_path)
+                 # Create a mock region that just returns the image
+                 mock_region = type("MockRegion", (), {"render": lambda self, crop=True: img})()
+                 metrics = self._extract_metrics(mock_region)
+                 examples[label].append(metrics)
+
+         # Check we have examples
+         for label, exs in examples.items():
+             if not exs:
+                 logger.warning(f"No examples for class '{label}'")
+                 return
+
+         # Check for class imbalance
+         example_counts = {label: len(exs) for label, exs in examples.items()}
+         max_count = max(example_counts.values())
+         min_count = min(example_counts.values())
+
+         imbalance_ratio = max_count / min_count if min_count > 0 else float("inf")
+         is_imbalanced = imbalance_ratio > 1.5  # Consider imbalanced if more than 1.5x difference
+
+         if is_imbalanced:
+             logger.info(
+                 f"Class imbalance detected: {example_counts} (ratio {imbalance_ratio:.1f}:1)"
+             )
+             logger.info("Using balanced accuracy for threshold selection")
+
+         # Find best thresholds for each metric
+         self.thresholds = {}
+         self.metrics_info = {}
+         metric_candidates = []  # Store all metrics with their scores
+
+         all_metrics = set()
+         for exs in examples.values():
+             for ex in exs:
+                 all_metrics.update(ex.keys())
+
+         for metric in all_metrics:
+             # Get all values for this metric
+             values_by_label = {}
+             for label, exs in examples.items():
+                 values_by_label[label] = [ex.get(metric, 0) for ex in exs]
+
+             # Find threshold that best separates classes (for binary)
+             if len(self.labels) == 2:
+                 label1, label2 = self.labels
+                 vals1 = values_by_label[label1]
+                 vals2 = values_by_label[label2]
+
+                 # Try different thresholds
+                 all_vals = vals1 + vals2
+                 best_threshold = None
+                 best_accuracy = 0
+                 best_direction = None
+
+                 for threshold in np.percentile(all_vals, [10, 20, 30, 40, 50, 60, 70, 80, 90]):
+                     # Test both directions
+                     for direction in ["higher", "lower"]:
+                         if direction == "higher":
+                             correct1 = sum(1 for v in vals1 if v > threshold)
+                             correct2 = sum(1 for v in vals2 if v <= threshold)
+                         else:
+                             correct1 = sum(1 for v in vals1 if v < threshold)
+                             correct2 = sum(1 for v in vals2 if v >= threshold)
+
+                         # Always use balanced accuracy for threshold selection
+                         # This finds fair thresholds regardless of class imbalance
+                         acc1 = correct1 / len(vals1) if len(vals1) > 0 else 0
+                         acc2 = correct2 / len(vals2) if len(vals2) > 0 else 0
+                         accuracy = (acc1 + acc2) / 2
+
+                         if accuracy > best_accuracy:
+                             best_accuracy = accuracy
+                             best_threshold = threshold
+                             best_direction = direction
+
+                 # Calculate Youden's J statistic for weight (TPR - FPR)
+                 if best_direction == "higher":
+                     tp = sum(1 for v in vals1 if v > best_threshold)
+                     fn = len(vals1) - tp
+                     tn = sum(1 for v in vals2 if v <= best_threshold)
+                     fp = len(vals2) - tn
+                 else:
+                     tp = sum(1 for v in vals1 if v < best_threshold)
+                     fn = len(vals1) - tp
+                     tn = sum(1 for v in vals2 if v >= best_threshold)
+                     fp = len(vals2) - tn
+
+                 tpr = tp / len(vals1) if len(vals1) > 0 else 0
+                 fpr = fp / len(vals2) if len(vals2) > 0 else 0
+                 youden_j = max(0.0, min(1.0, tpr - fpr))
+
+                 # Store all candidates
+                 metric_candidates.append(
+                     {
+                         "metric": metric,
+                         "youden_j": youden_j,
+                         "selection_accuracy": best_accuracy,
+                         "threshold": best_threshold,
+                         "direction": best_direction,
+                         "label1": label1,
+                         "label2": label2,
+                         "stats": {
+                             "mean_" + label1: np.mean(vals1),
+                             "mean_" + label2: np.mean(vals2),
+                             "std_" + label1: np.std(vals1),
+                             "std_" + label2: np.std(vals2),
+                         },
+                     }
+                 )
+
+         # Sort by selection accuracy
+         metric_candidates.sort(key=lambda x: x["selection_accuracy"], reverse=True)
+
+         # Use relaxed cutoff when imbalanced
+         keep_cutoff = 0.55 if is_imbalanced else 0.60
+
+         # Keep metrics that pass cutoff, or top 3 if none pass
+         kept_metrics = [m for m in metric_candidates if m["selection_accuracy"] > keep_cutoff]
+         if not kept_metrics and metric_candidates:
+             # Keep top 3 metrics even if they don't pass cutoff
+             kept_metrics = metric_candidates[:3]
+             logger.warning(
+                 f"No metrics passed cutoff {keep_cutoff}, keeping top {len(kept_metrics)} metrics"
+             )
+
+         # Store selected metrics
+         for candidate in kept_metrics:
+             metric = candidate["metric"]
+             label1 = candidate["label1"]
+             label2 = candidate["label2"]
+             self.thresholds[metric] = {
+                 "accuracy": candidate["youden_j"],  # Use Youden's J as weight
+                 "selection_accuracy": candidate["selection_accuracy"],
+                 "thresholds": {
+                     label1: (candidate["threshold"], candidate["direction"]),
+                     label2: (
+                         candidate["threshold"],
+                         "lower" if candidate["direction"] == "higher" else "higher",
+                     ),
+                 },
+             }
+             self.metrics_info[metric] = candidate["stats"]
+
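The per-metric weight is Youden's J = TPR − FPR at the chosen threshold; a worked instance of the arithmetic above (counts invented for illustration):

```python
# Say 8 of 10 "checked" values clear the threshold and 9 of 10 "unchecked" stay below it
tp, fn = 8, 2
tn, fp = 9, 1

tpr = tp / (tp + fn)                    # 0.8
fpr = fp / (fp + tn)                    # 0.1
weight = max(0.0, min(1.0, tpr - fpr))  # 0.7 -> this metric's vote weight in decide()
```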
+     def _load_config(self) -> None:
+         """Load configuration from file."""
+         try:
+             with open(self.config_path, "r") as f:
+                 config = json.load(f)
+
+             self.thresholds = config.get("thresholds", {})
+             self.metrics_info = config.get("metrics_info", {})
+
+             # Verify labels match
+             if config.get("labels") != self.labels:
+                 logger.warning(
+                     f"Saved labels {config.get('labels')} don't match current {self.labels}"
+                 )
+
+         except Exception as e:
+             logger.warning(f"Failed to load config: {e}")
+
+     def _get_training_counts(self) -> Dict[str, int]:
+         """Get count of examples per class."""
+         counts = {}
+         for label in self.labels:
+             label_dir = self.root_dir / label
+             counts[label] = len(list(label_dir.glob("*.png")))
+         counts["unlabeled"] = len(list((self.root_dir / "unlabeled").glob("*.png")))
+         return counts