adaptshot 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
adaptshot/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """AdaptShot: Human-Aligned Few-Shot Vision Learning."""
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ from .config.settings import AdaptShotConfig
6
+ from .core.learner import FewShotLearner
7
+ from .core.calibration import CalibrationEngine
8
+ from .core.act import ACTEngine
9
+ from .training.feedback_router import FeedbackRouter
10
+ from .training.up_ugf import UPUGFPruner
11
+
12
# Public API re-exported at package level; keep this list in sync with the
# `from .… import …` statements above so `from adaptshot import *` exposes
# exactly these names.
__all__ = [
    "AdaptShotConfig",
    "FewShotLearner",
    "CalibrationEngine",
    "ACTEngine",
    "FeedbackRouter",
    "UPUGFPruner",
]
File without changes
@@ -0,0 +1,55 @@
1
+ """Immutable configuration dataclasses for AdaptShot."""
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Literal, Optional
5
+
6
+ import torch
7
+
8
+
9
@dataclass(frozen=True)
class AdaptShotConfig:
    """
    Central, immutable configuration for the AdaptShot pipeline.

    Using a frozen dataclass guarantees that pipeline hyperparameters cannot be
    accidentally mutated during inference or training, which is critical for
    deterministic reproducibility and CI/CD validation.

    Raises:
        ValueError: If a numeric hyperparameter is outside its valid range
            (validated in ``__post_init__``).
    """
    # Core execution
    backbone: Literal["resnet18", "mobilenet_v3_small"] = "resnet18"
    device: Literal["cpu", "cuda", "mps"] = "cpu"  # CPU-first default
    seed: int = 42

    # Few-shot learning parameters
    n_way: int = 5  # Number of classes per episode
    k_shot: int = 10  # Support examples per class
    query_size: int = 15  # Query examples per class for evaluation

    # Similarity search
    use_faiss: bool = False  # Toggle FAISS-CPU acceleration
    faiss_nprobe: int = 8  # FAISS IVF index probing depth (if used later)

    # Calibration & uncertainty
    calibration_method: Literal["temperature", "conformal", "none"] = "temperature"
    ece_n_bins: int = 15  # Number of bins for Expected Calibration Error
    temperature_init: float = 1.0  # Initial temperature; must be > 0

    # Memory management (UP-UGF)
    max_buffer_size: int = 100

    # Logging & debugging
    verbose: bool = True
    log_dir: Optional[str] = None

    def __post_init__(self) -> None:
        """Validate configuration constraints immediately after creation.

        Raises:
            ValueError: If ``n_way``/``k_shot`` are not positive, if
                ``max_buffer_size < 10``, if ``temperature_init <= 0`` or if
                ``ece_n_bins < 1``.

        Warns:
            RuntimeWarning: If ``device == "cuda"`` but CUDA is not available.
        """
        if self.k_shot <= 0 or self.n_way <= 0:
            raise ValueError("n_way and k_shot must be positive integers.")
        if self.max_buffer_size < 10:
            raise ValueError("max_buffer_size must be >= 10 for meaningful few-shot operation.")
        # Temperature scaling divides logits by T, so T must be strictly positive.
        if self.temperature_init <= 0:
            raise ValueError("temperature_init must be > 0.")
        # ECE is computed over ece_n_bins equal-width bins; zero bins is meaningless.
        if self.ece_n_bins < 1:
            raise ValueError("ece_n_bins must be >= 1.")
        if self.device == "cuda" and not torch.cuda.is_available():
            import warnings
            # stacklevel=3 skips __post_init__ and the generated __init__ so the
            # warning points at the line that constructed the config.
            warnings.warn(
                "CUDA requested but not available. Runtime logic will fall back to CPU.",
                RuntimeWarning,
                stacklevel=3,
            )
File without changes
adaptshot/core/act.py ADDED
@@ -0,0 +1,134 @@
1
+ """Adaptive Confidence Thresholding (ACT) engine for few-shot predictions.
2
+
3
+ Dynamically adjusts per-class decision thresholds based on real-time
4
+ correction history and model uncertainty, reducing false acceptances
5
+ by requesting human feedback when the model is genuinely unsure.
6
+ """
7
+
8
import logging
from typing import Dict, Optional, Tuple

import numpy as np
12
+
13
logger = logging.getLogger(__name__)


class ACTEngine:
    """
    Adaptive Confidence Thresholding engine.

    Maintains a dynamic threshold τ_k for each class k. Each call to
    :meth:`should_accept` decides using the current threshold, then applies an
    additive update

        Δτ = η · (recent_incorrect_rate − γ · recent_correct_rate)

    where η is ``learning_rate`` and γ is ``feedback_cost_factor`` (the cost of
    requesting human feedback). Stored thresholds are kept inside
    ``[min_threshold, max_threshold]``.
    """

    def __init__(
        self,
        base_threshold: float = 0.65,
        learning_rate: float = 0.01,
        feedback_cost_factor: float = 0.5,
        min_threshold: float = 0.50,
        max_threshold: float = 0.95,
        n_classes: int = 100,
    ) -> None:
        """
        Args:
            base_threshold: Initial decision threshold for all classes; also the
                fallback for classes first seen when no other state exists and
                the default for :meth:`reset_class`.
            learning_rate: Step size for threshold adaptation (η)
            feedback_cost_factor: Weight penalizing unnecessary human queries (γ)
            min_threshold: Lower bound for τ_k
            max_threshold: Upper bound for τ_k
            n_classes: Preallocated number of class slots (indices 0..n_classes-1)

        Raises:
            ValueError: If ``min_threshold > max_threshold``.
        """
        if min_threshold > max_threshold:
            raise ValueError("min_threshold must be <= max_threshold.")
        self.eta = learning_rate
        self.gamma = feedback_cost_factor
        self.min_threshold = min_threshold
        self.max_threshold = max_threshold
        # Remembered so classes created later (and resets) start from the
        # configured base rather than a hard-coded constant.
        self.base_threshold = base_threshold

        # Per-class state: {class_idx: {"threshold", "correct", "incorrect", "total"}}
        self._class_state: Dict[int, Dict[str, float]] = {
            k: self._fresh_state(base_threshold) for k in range(n_classes)
        }

    @staticmethod
    def _fresh_state(threshold: float) -> Dict[str, float]:
        """Return a new per-class state record starting at ``threshold``."""
        return {"threshold": threshold, "correct": 0.0, "incorrect": 0.0, "total": 0.0}

    def _clip(self, value: float) -> float:
        """Clamp ``value`` into ``[min_threshold, max_threshold]``."""
        return float(np.clip(value, self.min_threshold, self.max_threshold))

    def _default_threshold(self) -> float:
        """Seed threshold for an unseen class: mean of existing thresholds, else the base."""
        if self._class_state:
            return float(np.mean([s["threshold"] for s in self._class_state.values()]))
        return self.base_threshold

    def should_accept(
        self,
        confidence: float,
        class_idx: int,
        recent_incorrect_rate: float = 0.0,
        recent_correct_rate: float = 1.0,
    ) -> Tuple[bool, str]:
        """
        Evaluate whether to accept a prediction or request human feedback.

        The decision uses the threshold *before* this call's update; the
        update only affects subsequent calls for the same class.

        Args:
            confidence: Calibrated confidence score [0, 1]
            class_idx: Predicted class index (unseen indices are created on the fly)
            recent_incorrect_rate: Fraction of recent corrections that were wrong [0, 1]
            recent_correct_rate: Fraction of recent confirmations that were right [0, 1]

        Returns:
            (accept: bool, action: str) where action is "ACCEPT" or "REQUEST_FEEDBACK"
        """
        # Ensure class state exists (handles dynamic class expansion)
        if class_idx not in self._class_state:
            self._class_state[class_idx] = self._fresh_state(self._default_threshold())

        state = self._class_state[class_idx]
        threshold = self._clip(state["threshold"])

        # Δτ = η * (incorrect_rate - γ * correct_rate); clamp on store so the raw
        # state (and the mean used to seed new classes) stays within bounds.
        delta = self.eta * (recent_incorrect_rate - self.gamma * recent_correct_rate)
        state["threshold"] = self._clip(threshold + delta)

        # Plain event counters: a call counts as "incorrect" when the supplied
        # incorrect rate exceeds 0.5, otherwise as "correct".
        state["total"] += 1.0
        if recent_incorrect_rate > 0.5:
            state["incorrect"] += 1.0
        else:
            state["correct"] += 1.0

        accept = bool(confidence >= threshold)
        action = "ACCEPT" if accept else "REQUEST_FEEDBACK"

        logger.debug(
            "ACT | Class %s | Conf: %.3f | τ: %.3f | Action: %s",
            class_idx, confidence, threshold, action,
        )

        return accept, action

    def get_threshold(self, class_idx: int) -> float:
        """Return the current (clamped) adaptive threshold for a given class.

        For an unknown class this returns the value a new class would be seeded
        with, without creating state for it.
        """
        state = self._class_state.get(class_idx)
        if state is not None:
            return self._clip(state["threshold"])
        return self._clip(self._default_threshold())

    def get_all_thresholds(self) -> Dict[int, float]:
        """Return a snapshot of all current class thresholds."""
        return {k: self.get_threshold(k) for k in self._class_state}

    def reset_class(self, class_idx: int, base_threshold: Optional[float] = None) -> None:
        """Reset adaptation state for a specific class (e.g., after dataset reset).

        Args:
            class_idx: Class to reset; unknown classes are left untouched.
            base_threshold: Threshold to restart from; defaults to the engine's
                configured ``base_threshold``.
        """
        if class_idx in self._class_state:
            start = self.base_threshold if base_threshold is None else base_threshold
            self._class_state[class_idx] = self._fresh_state(start)
@@ -0,0 +1,184 @@
1
+ """Online calibration module for few-shot vision predictions.
2
+
3
+ Implements temperature scaling, Expected Calibration Error (ECE) tracking,
4
+ and a conformal prediction stub for high-stakes deployment contexts.
5
+ """
6
+
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
class CalibrationEngine:
    """
    Tracks prediction calibration and applies post-hoc scaling to raw confidence scores.

    Designed for streaming few-shot evaluation where a held-out validation set is
    unavailable or too small. Maintains a sliding window of recent predictions and
    refits a single temperature parameter on it by grid search.

    Raw scores are treated as cosine similarities in [-1, 1] and mapped to [0, 1]
    with ``(s + 1) / 2`` before every calibration method, so all methods share
    one confidence scale.
    """

    def __init__(
        self,
        n_bins: int = 15,
        window_size: int = 100,
        temperature_init: float = 1.0,
        method: str = "temperature",
    ) -> None:
        """
        Args:
            n_bins: Number of equal-width bins for ECE computation (default: 15)
            window_size: Size of sliding window for online temperature fitting
            temperature_init: Initial temperature value (1.0 = no scaling); must be > 0
            method: Calibration method: "temperature", "conformal" (stub) or
                "none" (normalization only, no scaling)

        Raises:
            ValueError: If ``n_bins < 1`` or ``temperature_init <= 0``.
        """
        if n_bins < 1:
            raise ValueError("n_bins must be >= 1.")
        if temperature_init <= 0:
            raise ValueError("temperature_init must be > 0.")
        self.n_bins = n_bins
        self.window_size = window_size
        # Cast to float: nn.Parameter rejects integer tensors, so an int init
        # such as `temperature_init=1` would otherwise raise.
        self.temperature = torch.nn.Parameter(torch.tensor(float(temperature_init)))
        self.method = method

        # Sliding window buffers (parallel lists, trimmed to window_size in update())
        self._window_confidences: List[float] = []
        self._window_correct: List[bool] = []

        # ECE tracking state: one entry is appended per update() and never trimmed.
        self._ece_history: List[float] = []

    def update(
        self,
        raw_confidence: float,
        predicted_label: int,
        true_label: int,
    ) -> None:
        """
        Update calibration state with a new prediction and ground truth.

        Maintains a fixed-size sliding window to enable online temperature fitting
        without requiring a separate validation dataset.

        Args:
            raw_confidence: Cosine similarity score (unnormalized, typically [-1, 1])
            predicted_label: Predicted class index
            true_label: Ground truth class index
        """
        # Normalize raw confidence to [0, 1] for temperature scaling
        norm_conf = (raw_confidence + 1.0) / 2.0
        self._window_confidences.append(norm_conf)
        self._window_correct.append(predicted_label == true_label)

        # Maintain fixed window size
        if len(self._window_confidences) > self.window_size:
            self._window_confidences.pop(0)
            self._window_correct.pop(0)

        # Refit temperature once the window is at least half full (and >= 10 items)
        if len(self._window_confidences) >= max(10, self.window_size // 2):
            self._refit_temperature()

        # NOTE(review): ECE is tracked on the normalized *uncalibrated* window
        # confidences, not on temperature-scaled outputs.
        current_ece = self.compute_ece(
            np.array(self._window_confidences),
            np.array(self._window_correct, dtype=int)
        )
        self._ece_history.append(current_ece)

    def calibrate(self, raw_confidence: float) -> float:
        """
        Apply calibration to a raw confidence score.

        Args:
            raw_confidence: Unnormalized cosine similarity score

        Returns:
            Calibrated confidence in [0, 1]
        """
        norm_conf = (raw_confidence + 1.0) / 2.0

        if self.method == "conformal":
            # Conformal stub: conservative lower bound on the same [0, 1] scale
            # used by the temperature path and by update().
            return float(np.clip(norm_conf - 0.1, 0.0, 1.0))

        if self.method == "none":
            # No post-hoc scaling: only map the score into [0, 1].
            return float(np.clip(norm_conf, 0.0, 1.0))

        # Temperature scaling; clamp away from 0/1 so the logit stays finite
        norm_conf = np.clip(norm_conf, 1e-6, 1.0 - 1e-6)
        logit = np.log(norm_conf / (1.0 - norm_conf))
        scaled = 1.0 / (1.0 + np.exp(-logit / float(self.temperature)))
        return float(np.clip(scaled, 0.0, 1.0))

    def compute_ece(
        self,
        confidences: np.ndarray,
        labels_correct: np.ndarray,
    ) -> float:
        """
        Compute Expected Calibration Error (ECE) on a set of predictions.

        ECE measures the gap between average confidence and average accuracy
        across ``n_bins`` equal-width confidence bins, weighted by bin occupancy.
        Lower is better; <0.05 is the target.

        Args:
            confidences: Array (or sequence) of predicted confidence scores in [0, 1]
            labels_correct: Binary array (1 if correct, 0 if incorrect)

        Returns:
            ECE value in [0, 1] (0.0 for empty input)
        """
        confidences = np.asarray(confidences, dtype=float)
        labels_correct = np.asarray(labels_correct, dtype=float)
        if confidences.size == 0:
            return 0.0

        bin_boundaries = np.linspace(0.0, 1.0, self.n_bins + 1)
        ece = 0.0

        for i, (lower, upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
            in_bin = (confidences > lower) & (confidences <= upper)
            if i == 0:
                # Bins are half-open (lower, upper]; fold an exact 0.0 into the
                # first bin so such predictions are not silently dropped.
                in_bin |= confidences == lower
            prop_in_bin = in_bin.mean()

            if prop_in_bin > 0:
                avg_confidence = confidences[in_bin].mean()
                avg_accuracy = labels_correct[in_bin].mean()
                ece += np.abs(avg_accuracy - avg_confidence) * prop_in_bin

        return float(ece)

    def _refit_temperature(self) -> None:
        """
        Refit the temperature parameter by minimizing binary cross-entropy on the
        sliding window over a fixed grid of 25 candidates in [0.5, 3.0]
        (gradient-free, for stability on CPU).
        """
        if len(self._window_confidences) < 10:
            return

        confs = np.array(self._window_confidences, dtype=np.float32)
        correct = np.array(self._window_correct, dtype=np.float32)

        # Clamp to prevent log(0)
        confs = np.clip(confs, 1e-6, 1.0 - 1e-6)
        logits = np.log(confs / (1.0 - confs))

        # Grid search over reasonable temperature range [0.5, 3.0]
        candidates = np.linspace(0.5, 3.0, 25)
        best_loss = np.inf
        best_T = float(self.temperature)

        for T in candidates:
            scaled_confs = 1.0 / (1.0 + np.exp(-(logits / T)))
            # Binary cross-entropy of "prediction was correct" vs scaled confidence
            loss = -np.mean(
                correct * np.log(scaled_confs + 1e-6)
                + (1 - correct) * np.log(1.0 - scaled_confs + 1e-6)
            )
            if loss < best_loss:
                best_loss = loss
                best_T = T

        self.temperature = torch.nn.Parameter(torch.tensor(best_T))

    @property
    def current_ece(self) -> float:
        """Return the most recently computed ECE (0.0 before the first update)."""
        return self._ece_history[-1] if self._ece_history else 0.0

    @property
    def current_temperature(self) -> float:
        """Return the current temperature scaling parameter."""
        return float(self.temperature)
@@ -0,0 +1,70 @@
1
+ """Frozen backbone feature extraction with TorchScript compatibility."""
2
+
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torchvision.models as models
9
+ from PIL import Image
10
+ from torchvision import transforms
11
+
12
+ from ..config.settings import AdaptShotConfig
13
+
14
# Type alias for flexible image input: a file path, a numpy array, a PIL image,
# or a torch tensor (extract_embedding converts each of these to a PIL image
# before preprocessing).
ImageInput = Union[str, np.ndarray, Image.Image, torch.Tensor]
16
+
17
# Registry for backbone factories (extensible without modifying core logic).
# Each value is a zero-argument callable: wrapping construction in a lambda
# defers building the model (and whatever weight loading torchvision performs
# for the IMAGENET1K_V1 weights) until the entry is actually called.
# Keys should match the `backbone` Literal in AdaptShotConfig.
BackboneRegistry = {
    "resnet18": lambda: models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1),
    "mobilenet_v3_small": lambda: models.mobilenet_v3_small(
        weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
    ),
}
24
+
25
+
26
def _get_preprocess_transform(img_size: int = 224) -> transforms.Compose:
    """Build the resize → tensor → normalize pipeline for ImageNet-pretrained backbones.

    Args:
        img_size: Side length of the square (bilinear) resize target.

    Returns:
        A ``transforms.Compose`` producing a normalized CHW float tensor.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    target_size = (img_size, img_size)

    steps = [
        transforms.Resize(target_size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
33
+
34
+
35
# Cache of loaded, frozen backbones keyed by (backbone name, device), so the
# pretrained model is built once per key instead of once per image.
_BACKBONE_CACHE: dict = {}


def _resolve_device(device: str) -> str:
    """Fall back to CPU when CUDA is requested but unavailable (as the config warning promises)."""
    if device == "cuda" and not torch.cuda.is_available():
        return "cpu"
    return device


def _load_backbone(name: str, device: str) -> nn.Module:
    """Return a cached eval-mode backbone with its classification head replaced by Identity."""
    key = (name, device)
    model = _BACKBONE_CACHE.get(key)
    if model is None:
        model = BackboneRegistry[name]()
        # ResNets expose their head as `.fc`, MobileNetV3 as `.classifier`.
        # Replacing whichever exists makes the forward pass return pooled
        # features instead of ImageNet class logits.
        if hasattr(model, "fc"):
            model.fc = nn.Identity()
        elif hasattr(model, "classifier"):
            model.classifier = nn.Identity()
        model.requires_grad_(False)
        model.to(device)
        model.eval()
        _BACKBONE_CACHE[key] = model
    return model


def extract_embedding(
    image: ImageInput,
    config: AdaptShotConfig,
    return_numpy: bool = True,
) -> Union[torch.Tensor, np.ndarray]:
    """Extract a feature embedding from an input image using a frozen backbone.

    Args:
        image: A file path, numpy array (anything ``PIL.Image.fromarray``
            accepts), PIL image, or torch tensor (CHW, or HWC when the first
            dimension is not 1 or 3).
        config: Pipeline configuration; ``backbone`` and ``device`` are read.
        return_numpy: If True, return a CPU ``np.ndarray``; otherwise return the
            ``torch.Tensor`` on the model's device.

    Returns:
        The embedding for the single image, with the batch dimension removed.

    Raises:
        ValueError: If ``config.backbone`` is not a key of ``BackboneRegistry``.
    """
    if config.backbone not in BackboneRegistry:
        raise ValueError(f"Unknown backbone: {config.backbone}. Available: {list(BackboneRegistry.keys())}")

    device = _resolve_device(config.device)
    backbone = _load_backbone(config.backbone, device)
    preprocess = _get_preprocess_transform()

    # Convert every supported input type to a PIL image.
    if isinstance(image, str):
        # Context manager closes the underlying file once pixels are loaded.
        with Image.open(image) as img:
            image = img.convert("RGB")
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, torch.Tensor):
        if image.dim() == 3 and image.shape[0] not in (1, 3):
            image = image.permute(2, 0, 1)  # HWC -> CHW
        image = transforms.ToPILImage()(image.cpu())

    # Normalize() uses 3-channel statistics; grayscale, RGBA or palette images
    # would otherwise fail there.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Apply transforms and add batch dimension
    image_tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        embedding = backbone(image_tensor).squeeze(0)

    return embedding.detach().cpu().numpy() if return_numpy else embedding