adaptshot 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptshot/__init__.py +19 -0
- adaptshot/config/__init__.py +0 -0
- adaptshot/config/settings.py +55 -0
- adaptshot/core/__init__.py +0 -0
- adaptshot/core/act.py +134 -0
- adaptshot/core/calibration.py +184 -0
- adaptshot/core/extractor.py +70 -0
- adaptshot/core/learner.py +328 -0
- adaptshot/core/similarity.py +118 -0
- adaptshot/data/__init__.py +0 -0
- adaptshot/evaluation/__init__.py +0 -0
- adaptshot/training/__init__.py +0 -0
- adaptshot/training/feedback_router.py +136 -0
- adaptshot/training/finetune.py +172 -0
- adaptshot/training/up_ugf.py +134 -0
- adaptshot/ui/__init__.py +0 -0
- adaptshot/ui/app.py +149 -0
- adaptshot/utils/__init__.py +0 -0
- adaptshot/utils/determinism.py +89 -0
- adaptshot/utils/io.py +99 -0
- adaptshot-0.1.0.dist-info/METADATA +289 -0
- adaptshot-0.1.0.dist-info/RECORD +25 -0
- adaptshot-0.1.0.dist-info/WHEEL +5 -0
- adaptshot-0.1.0.dist-info/licenses/LICENSE +1 -0
- adaptshot-0.1.0.dist-info/top_level.txt +1 -0
adaptshot/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""AdaptShot: Human-Aligned Few-Shot Vision Learning."""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.1.0"
|
|
4
|
+
|
|
5
|
+
from .config.settings import AdaptShotConfig
|
|
6
|
+
from .core.learner import FewShotLearner
|
|
7
|
+
from .core.calibration import CalibrationEngine
|
|
8
|
+
from .core.act import ACTEngine
|
|
9
|
+
from .training.feedback_router import FeedbackRouter
|
|
10
|
+
from .training.up_ugf import UPUGFPruner
|
|
11
|
+
|
|
12
|
+
# Package-level public API: the names bound by ``from adaptshot import *``.
__all__ = [
    "AdaptShotConfig", "FewShotLearner", "CalibrationEngine",
    "ACTEngine", "FeedbackRouter", "UPUGFPruner",
]
|
|
File without changes
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""Immutable configuration dataclasses for AdaptShot."""
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Literal, Optional
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass(frozen=True)
class AdaptShotConfig:
    """
    Central, immutable configuration for the AdaptShot pipeline.

    Using a frozen dataclass guarantees that pipeline hyperparameters cannot be
    accidentally mutated during inference or training, which is critical for
    deterministic reproducibility and CI/CD validation.

    Raises:
        ValueError: From ``__post_init__`` when a field is outside its valid range.
    """
    # Core execution
    backbone: Literal["resnet18", "mobilenet_v3_small"] = "resnet18"
    device: Literal["cpu", "cuda", "mps"] = "cpu"  # CPU-first default
    seed: int = 42

    # Few-shot learning parameters
    n_way: int = 5  # Number of classes per episode
    k_shot: int = 10  # Support examples per class
    query_size: int = 15  # Query examples per class for evaluation

    # Similarity search
    use_faiss: bool = False  # Toggle FAISS-CPU acceleration
    faiss_nprobe: int = 8  # FAISS IVF index probing depth (if used later)

    # Calibration & uncertainty
    calibration_method: Literal["temperature", "conformal", "none"] = "temperature"
    ece_n_bins: int = 15  # Number of bins for Expected Calibration Error
    temperature_init: float = 1.0

    # Memory management (UP-UGF)
    max_buffer_size: int = 100

    # Logging & debugging
    verbose: bool = True
    log_dir: Optional[str] = None

    def __post_init__(self) -> None:
        """Validate configuration constraints immediately after creation.

        Raises:
            ValueError: If ``n_way``/``k_shot`` are not positive, ``max_buffer_size``
                is below 10, ``ece_n_bins`` is not positive, or
                ``temperature_init`` is not strictly positive.
        """
        if self.k_shot <= 0 or self.n_way <= 0:
            raise ValueError("n_way and k_shot must be positive integers.")
        if self.max_buffer_size < 10:
            raise ValueError("max_buffer_size must be >= 10 for meaningful few-shot operation.")
        if self.ece_n_bins <= 0:
            # ECE with zero (or negative) bins is undefined.
            raise ValueError("ece_n_bins must be a positive integer.")
        if self.temperature_init <= 0:
            # The temperature divides logits: zero divides by zero and a
            # negative value would invert the confidence ordering.
            raise ValueError("temperature_init must be > 0.")
        if self.device == "cuda" and not torch.cuda.is_available():
            import warnings
            warnings.warn(
                "CUDA requested but not available. Runtime logic will fall back to CPU.",
                RuntimeWarning
            )
|
|
File without changes
|
adaptshot/core/act.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
"""Adaptive Confidence Thresholding (ACT) engine for few-shot predictions.
|
|
2
|
+
|
|
3
|
+
Dynamically adjusts per-class decision thresholds based on real-time
|
|
4
|
+
correction history and model uncertainty, reducing false acceptances
|
|
5
|
+
by requesting human feedback when the model is genuinely unsure.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from typing import Dict, Tuple
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
# Module-level logger; ACTEngine.should_accept emits one DEBUG record per decision through it.
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ACTEngine:
    """
    Adaptive Confidence Thresholding engine.

    Maintains a dynamic threshold τ_k for each class k. Every call to
    ``should_accept`` compares the confidence against the current τ_k and then
    nudges τ_k by

        Δτ = η · (incorrect_rate − γ · correct_rate)

    where the two rates are supplied by the caller and γ weights the cost of
    requesting human feedback. Stored thresholds are always kept inside
    ``[min_threshold, max_threshold]``.
    """

    def __init__(
        self,
        base_threshold: float = 0.65,
        learning_rate: float = 0.01,
        feedback_cost_factor: float = 0.5,
        min_threshold: float = 0.50,
        max_threshold: float = 0.95,
        n_classes: int = 100,
    ) -> None:
        """
        Args:
            base_threshold: Initial decision threshold for all classes, and the
                fallback for classes first seen when no other class exists
            learning_rate: Step size for threshold adaptation (η)
            feedback_cost_factor: Weight penalizing unnecessary human queries (γ)
            min_threshold: Lower bound for τ_k
            max_threshold: Upper bound for τ_k
            n_classes: Preallocated number of class slots

        Raises:
            ValueError: If ``min_threshold`` is greater than ``max_threshold``.
        """
        if min_threshold > max_threshold:
            raise ValueError("min_threshold must not exceed max_threshold.")
        self.eta = learning_rate
        self.gamma = feedback_cost_factor
        self.min_threshold = min_threshold
        self.max_threshold = max_threshold
        # Remembered so dynamically added classes honour the configured start value
        self.base_threshold = base_threshold

        # Per-class state: {class_idx: {"threshold": float, "correct": float, "incorrect": float, "total": float}}
        self._class_state: Dict[int, Dict[str, float]] = {
            k: self._fresh_state(base_threshold) for k in range(n_classes)
        }

    def _clip(self, value: float) -> float:
        """Clamp ``value`` into ``[min_threshold, max_threshold]`` as a Python float."""
        return float(np.clip(value, self.min_threshold, self.max_threshold))

    def _fresh_state(self, threshold: float) -> Dict[str, float]:
        """Return a zeroed per-class state dict whose threshold is ``threshold`` clamped to bounds."""
        return {
            "threshold": self._clip(threshold),
            "correct": 0.0,
            "incorrect": 0.0,
            "total": 0.0,
        }

    def should_accept(
        self,
        confidence: float,
        class_idx: int,
        recent_incorrect_rate: float = 0.0,
        recent_correct_rate: float = 1.0,
    ) -> Tuple[bool, str]:
        """
        Evaluate whether to accept a prediction or request human feedback.

        The decision uses the class threshold as it was *before* this call;
        the threshold is then adapted for subsequent calls.

        Args:
            confidence: Calibrated confidence score [0, 1]
            class_idx: Predicted class index
            recent_incorrect_rate: Fraction of recent corrections that were wrong [0, 1]
            recent_correct_rate: Fraction of recent confirmations that were right [0, 1]

        Returns:
            (accept: bool, action: str) where action is "ACCEPT" or "REQUEST_FEEDBACK"
        """
        # Unseen class: start from the mean of existing thresholds, or from
        # base_threshold when no class exists yet.
        if class_idx not in self._class_state:
            self._class_state[class_idx] = self._fresh_state(self.get_threshold(class_idx))

        state = self._class_state[class_idx]
        threshold = self._clip(state["threshold"])

        # Update threshold: Δτ = η * (incorrect_rate - γ * correct_rate),
        # clamped so stored state never leaves [min_threshold, max_threshold].
        delta = self.eta * (recent_incorrect_rate - self.gamma * recent_correct_rate)
        state["threshold"] = self._clip(threshold + delta)

        # Plain counters: one "incorrect" tick when the caller reports a
        # majority-incorrect recent history, otherwise one "correct" tick.
        state["total"] += 1.0
        if recent_incorrect_rate > 0.5:
            state["incorrect"] += 1.0
        else:
            state["correct"] += 1.0

        accept = confidence >= threshold
        action = "ACCEPT" if accept else "REQUEST_FEEDBACK"

        logger.debug(
            "ACT | Class %s | Conf: %.3f | τ: %.3f | Action: %s",
            class_idx, confidence, threshold, action,
        )

        return accept, action

    def get_threshold(self, class_idx: int) -> float:
        """Return the current adaptive threshold for a given class.

        For an unknown class this is the clamped mean of all existing
        thresholds, or the clamped ``base_threshold`` if no class exists yet.
        """
        state = self._class_state.get(class_idx)
        if state is not None:
            return self._clip(state["threshold"])
        if self._class_state:
            return self._clip(np.mean([s["threshold"] for s in self._class_state.values()]))
        return self._clip(self.base_threshold)

    def get_all_thresholds(self) -> Dict[int, float]:
        """Return a snapshot of all current class thresholds."""
        return {k: self.get_threshold(k) for k in self._class_state}

    def reset_class(self, class_idx: int, base_threshold: float = 0.65) -> None:
        """Reset adaptation state for a specific, already-known class (e.g., after dataset reset).

        Unknown classes are left untouched.
        """
        if class_idx in self._class_state:
            self._class_state[class_idx] = self._fresh_state(base_threshold)
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""Online calibration module for few-shot vision predictions.
|
|
2
|
+
|
|
3
|
+
Implements temperature scaling, Expected Calibration Error (ECE) tracking,
|
|
4
|
+
and a conformal prediction stub for high-stakes deployment contexts.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Dict, List, Optional, Tuple
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import torch
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class CalibrationEngine:
    """
    Tracks prediction calibration and applies post-hoc scaling to raw confidence scores.

    Designed for streaming few-shot evaluation where a held-out validation set is
    unavailable or too small. Maintains a sliding window of recent predictions to
    fit temperature parameters online.

    Supported ``method`` values:
        - ``"temperature"``: map the cosine score to [0, 1], then apply the
          fitted temperature in logit space.
        - ``"conformal"``: stub that returns a fixed conservative offset.
        - ``"none"``: map the cosine score to [0, 1] without temperature scaling.
    """

    def __init__(
        self,
        n_bins: int = 15,
        window_size: int = 100,
        temperature_init: float = 1.0,
        method: str = "temperature",
    ) -> None:
        """
        Args:
            n_bins: Number of bins for ECE computation (default: 15)
            window_size: Size of sliding window for online temperature fitting
            temperature_init: Initial temperature value (1.0 = no scaling)
            method: Calibration method ("temperature", "conformal" or "none")
        """
        self.n_bins = n_bins
        self.window_size = window_size
        self.temperature = torch.nn.Parameter(torch.tensor(temperature_init))
        self.method = method

        # Sliding window buffers: normalized confidence and whether the
        # prediction matched the ground truth.
        self._window_confidences: List[float] = []
        self._window_correct: List[bool] = []

        # ECE tracking state: one value appended per update() call
        self._ece_history: List[float] = []

    def update(
        self,
        raw_confidence: float,
        predicted_label: int,
        true_label: int,
    ) -> None:
        """
        Update calibration state with a new prediction and ground truth.

        Maintains a fixed-size sliding window to enable online temperature fitting
        without requiring a separate validation dataset.

        Args:
            raw_confidence: Cosine similarity score (unnormalized, typically [-1, 1])
            predicted_label: Predicted class index
            true_label: Ground truth class index
        """
        # Normalize raw confidence to [0, 1] for temperature scaling
        norm_conf = (raw_confidence + 1.0) / 2.0
        self._window_confidences.append(norm_conf)
        self._window_correct.append(predicted_label == true_label)

        # Maintain fixed window size
        if len(self._window_confidences) > self.window_size:
            self._window_confidences.pop(0)
            self._window_correct.pop(0)

        # Refit temperature if window is sufficiently populated
        if len(self._window_confidences) >= max(10, self.window_size // 2):
            self._refit_temperature()

        # Update running ECE.
        # NOTE(review): this measures the normalized *pre-temperature* scores,
        # not the output of calibrate() — confirm this is the intended metric.
        current_ece = self.compute_ece(
            np.array(self._window_confidences),
            np.array(self._window_correct, dtype=int)
        )
        self._ece_history.append(current_ece)

    def calibrate(self, raw_confidence: float) -> float:
        """
        Apply calibration to a raw confidence score.

        Args:
            raw_confidence: Unnormalized cosine similarity score

        Returns:
            Calibrated confidence in [0, 1]
        """
        if self.method == "conformal":
            # Conformal stub: return conservative lower bound.
            # NOTE(review): unlike the other branches this operates on the raw
            # cosine score, not the [0, 1]-normalized one — confirm intent.
            return max(0.0, raw_confidence - 0.1)

        # Map cosine similarity from [-1, 1] onto [0, 1]
        norm_conf = (raw_confidence + 1.0) / 2.0

        if self.method == "none":
            # No post-hoc scaling requested: report the normalized score as-is
            return float(np.clip(norm_conf, 0.0, 1.0))

        # Temperature scaling in logit space; clamp so the logit stays finite
        norm_conf = np.clip(norm_conf, 1e-6, 1.0 - 1e-6)
        logit = np.log(norm_conf / (1.0 - norm_conf))
        scaled = 1.0 / (1.0 + np.exp(-logit / float(self.temperature)))
        return float(np.clip(scaled, 0.0, 1.0))

    def compute_ece(
        self,
        confidences: np.ndarray,
        labels_correct: np.ndarray,
    ) -> float:
        """
        Compute Expected Calibration Error (ECE) on a set of predictions.

        ECE measures the gap between average confidence and average accuracy
        across confidence bins. Lower is better; <0.05 is the target.

        Args:
            confidences: Array of predicted confidence scores in [0, 1]
            labels_correct: Binary array (1 if correct, 0 if incorrect)

        Returns:
            ECE value in [0, 1]
        """
        confidences = np.asarray(confidences, dtype=float)
        labels_correct = np.asarray(labels_correct, dtype=float)
        if confidences.size == 0:
            return 0.0

        bin_boundaries = np.linspace(0.0, 1.0, self.n_bins + 1)
        ece = 0.0

        for i in range(self.n_bins):
            lower, upper = bin_boundaries[i], bin_boundaries[i + 1]
            # Bins are (lower, upper]; the first bin also includes its lower
            # edge so a confidence of exactly 0.0 is not silently dropped.
            above_lower = confidences >= lower if i == 0 else confidences > lower
            in_bin = above_lower & (confidences <= upper)
            prop_in_bin = in_bin.mean()

            if prop_in_bin > 0:
                avg_confidence = confidences[in_bin].mean()
                avg_accuracy = labels_correct[in_bin].mean()
                ece += np.abs(avg_accuracy - avg_confidence) * prop_in_bin

        return float(ece)

    def _refit_temperature(self) -> None:
        """
        Refit temperature parameter using NLL loss on the sliding window.
        Uses a simple gradient-free grid search over [0.5, 3.0] for stability on CPU.
        """
        if len(self._window_confidences) < 10:
            return

        confs = np.array(self._window_confidences, dtype=np.float32)
        correct = np.array(self._window_correct, dtype=np.float32)

        # Clamp to prevent log(0)
        confs = np.clip(confs, 1e-6, 1.0 - 1e-6)
        logits = np.log(confs / (1.0 - confs))

        # Grid search over reasonable temperature range [0.5, 3.0]
        candidates = np.linspace(0.5, 3.0, 25)
        best_loss = np.inf
        best_T = float(self.temperature)

        for T in candidates:
            scaled_logits = logits / T
            scaled_confs = 1.0 / (1.0 + np.exp(-scaled_logits))
            # Binary cross-entropy loss
            loss = -np.mean(correct * np.log(scaled_confs + 1e-6) + (1 - correct) * np.log(1.0 - scaled_confs + 1e-6))
            if loss < best_loss:
                best_loss = loss
                best_T = T

        self.temperature = torch.nn.Parameter(torch.tensor(best_T))

    @property
    def current_ece(self) -> float:
        """Return the most recently computed ECE."""
        return self._ece_history[-1] if self._ece_history else 0.0

    @property
    def current_temperature(self) -> float:
        """Return the current temperature scaling parameter."""
        return float(self.temperature)
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""Frozen backbone feature extraction with TorchScript compatibility."""
|
|
2
|
+
|
|
3
|
+
from typing import Union
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
import torchvision.models as models
|
|
9
|
+
from PIL import Image
|
|
10
|
+
from torchvision import transforms
|
|
11
|
+
|
|
12
|
+
from ..config.settings import AdaptShotConfig
|
|
13
|
+
|
|
14
|
+
# Type alias for flexible image input
|
|
15
|
+
ImageInput = Union[str, np.ndarray, Image.Image, torch.Tensor]
|
|
16
|
+
|
|
17
|
+
# Registry for backbone factories (extensible without modifying core logic)
|
|
18
|
+
# Each entry maps a backbone name to a zero-argument factory that builds the
# torchvision model with its IMAGENET1K_V1 pretrained weights.
BackboneRegistry = dict(
    resnet18=lambda: models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1),
    mobilenet_v3_small=lambda: models.mobilenet_v3_small(
        weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
    ),
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _get_preprocess_transform(img_size: int = 224) -> transforms.Compose:
    """Build the preprocessing pipeline expected by the ImageNet-pretrained backbones.

    The pipeline resizes bilinearly to a square ``img_size`` x ``img_size``,
    converts to a tensor, and normalizes each channel with ImageNet statistics.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    square = (img_size, img_size)

    steps = [
        transforms.Resize(square, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return transforms.Compose(steps)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Loaded, frozen, headless backbones keyed by (backbone name, device) so that
# repeated calls do not rebuild the network and reload pretrained weights.
_BACKBONE_CACHE: dict = {}


def _resolve_device(requested: str) -> str:
    """Return ``requested`` if that accelerator is usable, otherwise ``"cpu"``."""
    if requested == "cuda" and not torch.cuda.is_available():
        return "cpu"
    if requested == "mps":
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is None or not mps_backend.is_available():
            return "cpu"
    return requested


def _load_backbone(name: str, device: str) -> nn.Module:
    """Build (once per name/device) a frozen backbone without its classification head."""
    key = (name, device)
    cached = _BACKBONE_CACHE.get(key)
    if cached is not None:
        return cached

    backbone = BackboneRegistry[name]()
    # Remove the ImageNet classifier so the forward pass yields pooled features.
    # ResNets name the head ``fc``; MobileNetV3 names it ``classifier``.
    if hasattr(backbone, "fc"):
        backbone.fc = nn.Identity()
    elif hasattr(backbone, "classifier"):
        backbone.classifier = nn.Identity()
    backbone.requires_grad_(False)
    backbone.to(device)
    backbone.eval()

    _BACKBONE_CACHE[key] = backbone
    return backbone


def _to_pil_rgb(image: ImageInput) -> Image.Image:
    """Convert any supported image input into a 3-channel RGB PIL image.

    Raises:
        TypeError: If ``image`` is not a path, ndarray, PIL image or tensor.
    """
    if isinstance(image, str):
        # Context manager closes the file; convert() returns an independent copy.
        with Image.open(image) as opened:
            return opened.convert("RGB")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, torch.Tensor):
        tensor = image.detach().cpu()
        # Heuristic from the original code: a 3-D tensor whose first dim is not
        # 1 or 3 is treated as HWC and moved to CHW for ToPILImage.
        if tensor.dim() == 3 and tensor.shape[0] not in (1, 3):
            tensor = tensor.permute(2, 0, 1)
        image = transforms.ToPILImage()(tensor)
    if not isinstance(image, Image.Image):
        raise TypeError(f"Unsupported image input type: {type(image).__name__}")
    # Grayscale/RGBA/palette inputs must become 3-channel to match Normalize.
    return image.convert("RGB")


def extract_embedding(
    image: ImageInput,
    config: AdaptShotConfig,
    return_numpy: bool = True,
) -> Union[torch.Tensor, np.ndarray]:
    """Extract a feature embedding from an input image using a frozen backbone.

    The backbone is loaded once per (backbone, device) and reused. If the
    configured accelerator is unavailable, computation falls back to CPU.
    For ``mobilenet_v3_small`` the classifier head is removed, so the result
    is the pooled feature vector rather than ImageNet class logits.

    Args:
        image: File path, ndarray, PIL image, or tensor (CHW, or HWC when the
            first dim is not 1 or 3).
        config: Pipeline configuration; ``backbone`` and ``device`` are read.
        return_numpy: Return a NumPy array (on CPU) instead of a torch tensor.

    Returns:
        1-D embedding as ``np.ndarray`` or ``torch.Tensor``.

    Raises:
        ValueError: If ``config.backbone`` is not in ``BackboneRegistry``.
        TypeError: If ``image`` has an unsupported type.
    """
    if config.backbone not in BackboneRegistry:
        raise ValueError(f"Unknown backbone: {config.backbone}. Available: {list(BackboneRegistry.keys())}")

    device = _resolve_device(config.device)
    backbone = _load_backbone(config.backbone, device)

    # Apply transforms and add batch dimension
    preprocess = _get_preprocess_transform()
    image_tensor = preprocess(_to_pil_rgb(image)).unsqueeze(0).to(device)

    with torch.no_grad():
        embedding = backbone(image_tensor).squeeze(0)

    return embedding.detach().cpu().numpy() if return_numpy else embedding
|