adaptshot 0.1.1__tar.gz → 0.2.0.dev0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adaptshot-0.1.1/src/adaptshot.egg-info → adaptshot-0.2.0.dev0}/PKG-INFO +48 -22
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/README.md +41 -16
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/pyproject.toml +16 -7
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/__init__.py +14 -1
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/config/settings.py +32 -9
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/core/act.py +24 -5
- adaptshot-0.2.0.dev0/src/adaptshot/core/backends/__init__.py +11 -0
- adaptshot-0.2.0.dev0/src/adaptshot/core/backends/onnx_backend.py +95 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/core/calibration.py +2 -3
- adaptshot-0.2.0.dev0/src/adaptshot/core/conformal.py +436 -0
- adaptshot-0.2.0.dev0/src/adaptshot/core/contrastive.py +512 -0
- adaptshot-0.2.0.dev0/src/adaptshot/core/explain.py +586 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/core/extractor.py +103 -28
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/core/learner.py +427 -37
- adaptshot-0.2.0.dev0/src/adaptshot/core/uncertainty.py +569 -0
- adaptshot-0.2.0.dev0/src/adaptshot/data/__init__.py +5 -0
- adaptshot-0.2.0.dev0/src/adaptshot/data/mobilenet_v3_small.onnx +0 -0
- adaptshot-0.2.0.dev0/src/adaptshot/data/mobilenet_v3_small.onnx.data +0 -0
- adaptshot-0.2.0.dev0/src/adaptshot/data/resnet18.onnx +0 -0
- adaptshot-0.2.0.dev0/src/adaptshot/data/resnet18.onnx.data +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/studio/app.py +1 -1
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/training/finetune.py +36 -11
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/training/up_ugf.py +34 -6
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/ui/app.py +1 -1
- adaptshot-0.2.0.dev0/src/adaptshot/utils/profiling.py +166 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0/src/adaptshot.egg-info}/PKG-INFO +48 -22
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot.egg-info/SOURCES.txt +16 -1
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot.egg-info/requires.txt +6 -3
- adaptshot-0.2.0.dev0/tests/test_conformal.py +132 -0
- adaptshot-0.2.0.dev0/tests/test_contrastive.py +100 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/tests/test_exceptions.py +15 -6
- adaptshot-0.2.0.dev0/tests/test_explain.py +173 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/tests/test_extractor.py +13 -6
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/tests/test_persistence.py +1 -1
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/tests/test_release_metadata.py +3 -3
- adaptshot-0.2.0.dev0/tests/test_uncertainty.py +156 -0
- adaptshot-0.1.1/src/adaptshot/utils/__init__.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/LICENSE +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/setup.cfg +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/config/__init__.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/core/__init__.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/core/similarity.py +0 -0
- {adaptshot-0.1.1/src/adaptshot/data → adaptshot-0.2.0.dev0/src/adaptshot/evaluation}/__init__.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/studio/__init__.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/studio/utils.py +0 -0
- {adaptshot-0.1.1/src/adaptshot/evaluation → adaptshot-0.2.0.dev0/src/adaptshot/training}/__init__.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/training/feedback_router.py +0 -0
- {adaptshot-0.1.1/src/adaptshot/training → adaptshot-0.2.0.dev0/src/adaptshot/ui}/__init__.py +0 -0
- {adaptshot-0.1.1/src/adaptshot/ui → adaptshot-0.2.0.dev0/src/adaptshot/utils}/__init__.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/utils/determinism.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/utils/exceptions.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/utils/io.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot/utils/migrations.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot.egg-info/dependency_links.txt +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot.egg-info/entry_points.txt +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/src/adaptshot.egg-info/top_level.txt +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/tests/test_calibration.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/tests/test_feedback_router.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/tests/test_learner_integration.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/tests/test_similarity.py +0 -0
- {adaptshot-0.1.1 → adaptshot-0.2.0.dev0}/tests/test_studio_utils.py +0 -0
|
@@ -1,16 +1,15 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: adaptshot
|
|
3
|
-
Version: 0.1.1
|
|
3
|
+
Version: 0.2.0.dev0
|
|
4
4
|
Summary: Human-aligned few-shot vision learning for resource-constrained environments
|
|
5
5
|
Author-email: Johnson Christopher Hassan <johnson2006christopher@gmail.com>
|
|
6
|
-
License: MIT
|
|
6
|
+
License-Expression: MIT
|
|
7
7
|
Project-URL: Homepage, https://github.com/johnson2006christopher/adaptshot
|
|
8
8
|
Project-URL: Documentation, https://github.com/johnson2006christopher/adaptshot/docs
|
|
9
9
|
Project-URL: Repository, https://github.com/johnson2006christopher/adaptshot.git
|
|
10
10
|
Project-URL: Bug Tracker, https://github.com/johnson2006christopher/adaptshot/issues
|
|
11
11
|
Classifier: Development Status :: 3 - Alpha
|
|
12
12
|
Classifier: Intended Audience :: Science/Research
|
|
13
|
-
Classifier: License :: OSI Approved :: MIT License
|
|
14
13
|
Classifier: Programming Language :: Python :: 3.9
|
|
15
14
|
Classifier: Programming Language :: Python :: 3.10
|
|
16
15
|
Classifier: Programming Language :: Python :: 3.11
|
|
@@ -19,10 +18,11 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
|
19
18
|
Requires-Python: >=3.9
|
|
20
19
|
Description-Content-Type: text/markdown
|
|
21
20
|
License-File: LICENSE
|
|
22
|
-
Requires-Dist: torch>=2.0.0
|
|
23
|
-
Requires-Dist: torchvision>=0.15.0
|
|
24
21
|
Requires-Dist: numpy>=1.24.0
|
|
25
22
|
Requires-Dist: Pillow>=9.0.0
|
|
23
|
+
Provides-Extra: torch
|
|
24
|
+
Requires-Dist: torch>=2.0.0; extra == "torch"
|
|
25
|
+
Requires-Dist: torchvision>=0.15.0; extra == "torch"
|
|
26
26
|
Provides-Extra: faiss
|
|
27
27
|
Requires-Dist: faiss-cpu>=1.7.4; extra == "faiss"
|
|
28
28
|
Provides-Extra: ui
|
|
@@ -32,6 +32,7 @@ Requires-Dist: gradio>=3.50.0; extra == "gui"
|
|
|
32
32
|
Requires-Dist: pandas>=2.0.0; extra == "gui"
|
|
33
33
|
Requires-Dist: onnx>=1.16.0; extra == "gui"
|
|
34
34
|
Requires-Dist: onnxscript>=0.3.0; extra == "gui"
|
|
35
|
+
Requires-Dist: onnxruntime>=1.14.0; extra == "gui"
|
|
35
36
|
Provides-Extra: dev
|
|
36
37
|
Requires-Dist: pytest>=7.4.0; extra == "dev"
|
|
37
38
|
Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
|
|
@@ -39,7 +40,7 @@ Requires-Dist: mypy>=1.5.0; extra == "dev"
|
|
|
39
40
|
Requires-Dist: ruff>=0.1.0; extra == "dev"
|
|
40
41
|
Requires-Dist: pre-commit>=3.4.0; extra == "dev"
|
|
41
42
|
Provides-Extra: all
|
|
42
|
-
Requires-Dist: adaptshot[dev,faiss,gui]; extra == "all"
|
|
43
|
+
Requires-Dist: adaptshot[dev,faiss,gui,torch]; extra == "all"
|
|
43
44
|
Dynamic: license-file
|
|
44
45
|
|
|
45
46
|
|
|
@@ -71,16 +72,19 @@ Dynamic: license-file
|
|
|
71
72
|
|
|
72
73
|
AdaptShot is a high-performance, CPU-optimized, human-in-the-loop few-shot vision library. It is designed to learn from every human correction, guarantee calibrated uncertainty, and run deterministically on edge hardware with minimal resources.
|
|
73
74
|
|
|
74
|
-
v0.
|
|
75
|
+
v0.2.0-dev is the current release, hardened with 92 regression tests, strict type-checking, and a comprehensive benchmark harness. Built in Tanzania by a self-taught engineer with nothing but a laptop and determination.
|
|
75
76
|
|
|
76
77
|
</div>
|
|
77
78
|
|
|
78
79
|
## 🚀 Key Features
|
|
79
80
|
|
|
80
81
|
* **CPU-First by Design**: Optimized for low-latency inference on standard CPUs, requiring less than 250MB of RAM.
|
|
81
|
-
* **Trustworthy & Calibrated**: Built-in **Expected Calibration Error (ECE)** minimization
|
|
82
|
+
* **Trustworthy & Calibrated**: Built-in **Expected Calibration Error (ECE)** minimization and **conformal prediction** with finite-sample coverage guarantees.
|
|
82
83
|
* **Human-in-the-Loop**: Integrated **FeedbackRouter** for real-time model adaptation through human expert corrections.
|
|
83
|
-
* **Continual Learning**: Implements **CA-EWC** (
|
|
84
|
+
* **Continual Learning**: Implements **head-only CA-EWC** (Fisher-regularized classification head fine-tuning, ~2K parameters) and **UP-UGF** (Uncertainty-Guided Forgetting with LSH-accelerated redundancy scoring) for stable, long-term learning without catastrophic forgetting.
|
|
85
|
+
* **Multi-Signal Uncertainty**: Epistemic (stochastic embedding perturbation sensitivity), aleatoric (k-NN entropy), and distributional (shrinkage-regularized Mahalanobis distance) uncertainty quantification with OOD detection. *(Full MC Dropout planned for future torch-dependent release.)*
|
|
86
|
+
* **Explainable AI**: Embedding-space feature attribution (which support examples influenced the prediction), confidence decomposition with historical penalty tracking, counterfactual explanations, and per-dimension saliency analysis.
|
|
87
|
+
* **Contrastive Prototypes**: Gradient-trained class representations via InfoNCE contrastive loss with 2-layer MLP projection head and EMA momentum prototype refinement.
|
|
84
88
|
* **Release Hardened**: Zero-config API, strict type safety, and a comprehensive benchmark harness for review and deployment readiness.
|
|
85
89
|
* **Deterministic**: Guaranteed reproducible results across different runs and hardware through strict seed management.
|
|
86
90
|
|
|
@@ -106,10 +110,19 @@ $ pip install adaptshot
|
|
|
106
110
|
|
|
107
111
|
</div>
|
|
108
112
|
|
|
113
|
+
**Core dependencies**: numpy, Pillow (~15 MB total). PyTorch is **optional** — install it only if you need training/fine-tuning:
|
|
114
|
+
|
|
115
|
+
```bash
|
|
116
|
+
pip install "adaptshot[torch]" # adds PyTorch + torchvision for training
|
|
117
|
+
```
|
|
118
|
+
|
|
119
|
+
> **Fast install**: The base library installs in under 60 seconds on standard connections — no GPU drivers, no CUDA, no 2 GB downloads.
|
|
120
|
+
|
|
109
121
|
### Optional Dependencies
|
|
110
122
|
|
|
111
123
|
AdaptShot provides optional extras for specialized workflows. The native Python API remains the source of truth; the GUI is an optional wrapper around it:
|
|
112
124
|
|
|
125
|
+
* **PyTorch (Training)**: `pip install "adaptshot[torch]"` (Required for CA-EWC fine-tuning and custom backbones)
|
|
113
126
|
* **FAISS Acceleration**: `pip install "adaptshot[faiss]"` (Recommended for support sets >100 images)
|
|
114
127
|
* **Gradio UI**: `pip install "adaptshot[ui]"` (For interactive pilots and human-in-the-loop dashboards)
|
|
115
128
|
* **Studio GUI**: `pip install "adaptshot[gui]"` (For the offline, folder-aware AdaptShot Studio workspace)
|
|
@@ -155,15 +168,21 @@ if result.uncertainty_flag:
|
|
|
155
168
|
|
|
156
169
|
---
|
|
157
170
|
|
|
158
|
-
## 🆕 What's New in v0.
|
|
159
|
-
|
|
160
|
-
- **
|
|
161
|
-
- **
|
|
162
|
-
- **
|
|
163
|
-
- **
|
|
164
|
-
- **
|
|
165
|
-
- **
|
|
166
|
-
- **
|
|
171
|
+
## 🆕 What's New in v0.2.0
|
|
172
|
+
|
|
173
|
+
- **Conformal Prediction**: Distribution-free prediction sets with true leave-one-out calibration guaranteeing finite-sample coverage at configurable significance levels
|
|
174
|
+
- **Contrastive Prototype Learning**: Gradient-trained InfoNCE class prototypes with 2-layer MLP projection head (full backpropagation through W1/b1/W2/b2)
|
|
175
|
+
- **Advanced Uncertainty**: Three complementary signals — epistemic (stochastic perturbation sensitivity), aleatoric (k-NN entropy), and distributional (shrinkage-regularized Mahalanobis OOD) — fused with mode-gated computation
|
|
176
|
+
- **XAI Explainability**: Embedding-space feature attribution, confidence decomposition with historical penalty tracking, counterfactual analysis, and per-dimension saliency
|
|
177
|
+
- **Cross-conformal prediction mode**: K-fold cross-conformal quantile averaging for more stable prediction sets
|
|
178
|
+
- **Bootstrap Temperature Calibration**: Autonomous LOO grid-search temperature optimization for cold-start scenarios
|
|
179
|
+
- **UP-UGF LSH Acceleration**: Approximate O(N log N) redundancy scoring via random projection locality-sensitive hashing for large buffers (>100 examples)
|
|
180
|
+
- **Memory Profiling**: `MemoryTracker` with tracemalloc + psutil instrumentation for verifying <250MB RAM operation
|
|
181
|
+
- **miniImageNet Benchmarks**: Standard few-shot benchmarks with baseline references (Prototypical Networks, Matching Networks, MAML)
|
|
182
|
+
- **ONNX Export**: Bundled backbone export script with SHA-256 verification for torch-free inference
|
|
183
|
+
- **ACT Symmetric Updates**: Mean-reverting threshold adaptation prevents monotonic drift in autonomous operation
|
|
184
|
+
- **37 new tests** (92 total) across conformal, contrastive, uncertainty, and explainability modules
|
|
185
|
+
- **12 new documentation pages**: Architecture deep-dive, algorithm theory, API reference, 5 tutorials, 2 GUI guides
|
|
167
186
|
|
|
168
187
|
## 🛠️ Configuration
|
|
169
188
|
|
|
@@ -191,7 +210,7 @@ AdaptShot uses a strictly typed, immutable `AdaptShotConfig` to ensure reproduci
|
|
|
191
210
|
| Parameter | Type | Default | Description |
|
|
192
211
|
| :--- | :--- | :--- | :--- |
|
|
193
212
|
| `similarity_metric` | `str` | `"euclidean"` | Distance metric (`cosine` or `euclidean`) |
|
|
194
|
-
| `inference_mode` | `str` | `"prototypical"` | Classification mode (`nearest_neighbor` or `
|
|
213
|
+
| `inference_mode` | `str` | `"prototypical"` | Classification mode (`nearest_neighbor`, `prototypical`, or `contrastive`) |
|
|
195
214
|
| `use_faiss` | `bool` | `False` | Enable FAISS-CPU acceleration for large support sets |
|
|
196
215
|
| `faiss_nprobe` | `int` | `8` | FAISS IVF index probing depth |
|
|
197
216
|
|
|
@@ -227,6 +246,15 @@ AdaptShot uses a strictly typed, immutable `AdaptShotConfig` to ensure reproduci
|
|
|
227
246
|
| `max_buffer_size` | `int` | `100` | Maximum replay buffer capacity (enforced by UP-UGF) |
|
|
228
247
|
| `log_dir` | `Optional[str]` | `None` | Optional log output directory |
|
|
229
248
|
|
|
249
|
+
### Advanced Algorithms (v0.2.0)
|
|
250
|
+
|
|
251
|
+
| Parameter | Type | Default | Description |
|
|
252
|
+
| :--- | :--- | :--- | :--- |
|
|
253
|
+
| `conformal_alpha` | `float` | `0.05` | Significance level for conformal prediction sets; must be in the open interval (0, 1) (e.g. 0.01–0.50) |
|
|
254
|
+
| `conformal_mode` | `str` | `"split"` | Conformal prediction mode (`split` or `cross`) |
|
|
255
|
+
| `uncertainty_mode` | `str` | `"ensemble"` | Uncertainty mode (`mcdropout`, `entropy`, `mahalanobis`, or `ensemble`) |
|
|
256
|
+
| `explainability_enabled` | `bool` | `True` | Enable XAI explainability for predictions |
|
|
257
|
+
|
|
230
258
|
---
|
|
231
259
|
|
|
232
260
|
## ☁️ Deployment
|
|
@@ -280,8 +308,6 @@ Built in Mbeya, Tanzania 🇹🇿
|
|
|
280
308
|
<p><i>"The best AI doesn't guess confidently. It learns humbly, admits uncertainty, and improves through every human correction."</i></p>
|
|
281
309
|
</div>
|
|
282
310
|
|
|
283
|
-
```
|
|
284
|
-
|
|
285
311
|
---
|
|
286
312
|
|
|
287
313
|
## 🔍 Summary of Key Updates
|
|
@@ -292,7 +318,7 @@ Built in Mbeya, Tanzania 🇹🇿
|
|
|
292
318
|
| ✅ Updated **Docs badge** to live MkDocs URL | Users can access accurate, searchable documentation immediately |
|
|
293
319
|
| ✅ Fixed **installation instructions** to match `pyproject.toml` extras | Prevents user confusion; ensures `pip install adaptshot[faiss]` works |
|
|
294
320
|
| ✅ Corrected **API signatures** to match actual code (`FewShotLearner`, `PredictionResult`) | Developers can copy-paste examples with confidence |
|
|
295
|
-
| ✅ Marked v0.1.
|
|
321
|
+
| ✅ Marked v0.1.2 content as **stable / released** | Confirms publication status is accurate |
|
|
296
322
|
| ✅ Removed placeholder links (`arXiv:2604.XXXXX`, `adaptshot.dev/docs`) | No broken links; only verified, working resources |
|
|
297
323
|
| ✅ Kept the native API as the primary workflow | Reinforces code-first usage even with the optional GUI |
|
|
298
324
|
| ✅ Standardized **citation format** to GitHub + version | Academically sound; reproducible referencing |
|
|
@@ -27,16 +27,19 @@
|
|
|
27
27
|
|
|
28
28
|
AdaptShot is a high-performance, CPU-optimized, human-in-the-loop few-shot vision library. It is designed to learn from every human correction, guarantee calibrated uncertainty, and run deterministically on edge hardware with minimal resources.
|
|
29
29
|
|
|
30
|
-
v0.
|
|
30
|
+
v0.2.0-dev is the current release, hardened with 92 regression tests, strict type-checking, and a comprehensive benchmark harness. Built in Tanzania by a self-taught engineer with nothing but a laptop and determination.
|
|
31
31
|
|
|
32
32
|
</div>
|
|
33
33
|
|
|
34
34
|
## 🚀 Key Features
|
|
35
35
|
|
|
36
36
|
* **CPU-First by Design**: Optimized for low-latency inference on standard CPUs, requiring less than 250MB of RAM.
|
|
37
|
-
* **Trustworthy & Calibrated**: Built-in **Expected Calibration Error (ECE)** minimization
|
|
37
|
+
* **Trustworthy & Calibrated**: Built-in **Expected Calibration Error (ECE)** minimization and **conformal prediction** with finite-sample coverage guarantees.
|
|
38
38
|
* **Human-in-the-Loop**: Integrated **FeedbackRouter** for real-time model adaptation through human expert corrections.
|
|
39
|
-
* **Continual Learning**: Implements **CA-EWC** (
|
|
39
|
+
* **Continual Learning**: Implements **head-only CA-EWC** (Fisher-regularized classification head fine-tuning, ~2K parameters) and **UP-UGF** (Uncertainty-Guided Forgetting with LSH-accelerated redundancy scoring) for stable, long-term learning without catastrophic forgetting.
|
|
40
|
+
* **Multi-Signal Uncertainty**: Epistemic (stochastic embedding perturbation sensitivity), aleatoric (k-NN entropy), and distributional (shrinkage-regularized Mahalanobis distance) uncertainty quantification with OOD detection. *(Full MC Dropout planned for future torch-dependent release.)*
|
|
41
|
+
* **Explainable AI**: Embedding-space feature attribution (which support examples influenced the prediction), confidence decomposition with historical penalty tracking, counterfactual explanations, and per-dimension saliency analysis.
|
|
42
|
+
* **Contrastive Prototypes**: Gradient-trained class representations via InfoNCE contrastive loss with 2-layer MLP projection head and EMA momentum prototype refinement.
|
|
40
43
|
* **Release Hardened**: Zero-config API, strict type safety, and a comprehensive benchmark harness for review and deployment readiness.
|
|
41
44
|
* **Deterministic**: Guaranteed reproducible results across different runs and hardware through strict seed management.
|
|
42
45
|
|
|
@@ -62,10 +65,19 @@ $ pip install adaptshot
|
|
|
62
65
|
|
|
63
66
|
</div>
|
|
64
67
|
|
|
68
|
+
**Core dependencies**: numpy, Pillow (~15 MB total). PyTorch is **optional** — install it only if you need training/fine-tuning:
|
|
69
|
+
|
|
70
|
+
```bash
|
|
71
|
+
pip install "adaptshot[torch]" # adds PyTorch + torchvision for training
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
> **Fast install**: The base library installs in under 60 seconds on standard connections — no GPU drivers, no CUDA, no 2 GB downloads.
|
|
75
|
+
|
|
65
76
|
### Optional Dependencies
|
|
66
77
|
|
|
67
78
|
AdaptShot provides optional extras for specialized workflows. The native Python API remains the source of truth; the GUI is an optional wrapper around it:
|
|
68
79
|
|
|
80
|
+
* **PyTorch (Training)**: `pip install "adaptshot[torch]"` (Required for CA-EWC fine-tuning and custom backbones)
|
|
69
81
|
* **FAISS Acceleration**: `pip install "adaptshot[faiss]"` (Recommended for support sets >100 images)
|
|
70
82
|
* **Gradio UI**: `pip install "adaptshot[ui]"` (For interactive pilots and human-in-the-loop dashboards)
|
|
71
83
|
* **Studio GUI**: `pip install "adaptshot[gui]"` (For the offline, folder-aware AdaptShot Studio workspace)
|
|
@@ -111,15 +123,21 @@ if result.uncertainty_flag:
|
|
|
111
123
|
|
|
112
124
|
---
|
|
113
125
|
|
|
114
|
-
## 🆕 What's New in v0.
|
|
115
|
-
|
|
116
|
-
- **
|
|
117
|
-
- **
|
|
118
|
-
- **
|
|
119
|
-
- **
|
|
120
|
-
- **
|
|
121
|
-
- **
|
|
122
|
-
- **
|
|
126
|
+
## 🆕 What's New in v0.2.0
|
|
127
|
+
|
|
128
|
+
- **Conformal Prediction**: Distribution-free prediction sets with true leave-one-out calibration guaranteeing finite-sample coverage at configurable significance levels
|
|
129
|
+
- **Contrastive Prototype Learning**: Gradient-trained InfoNCE class prototypes with 2-layer MLP projection head (full backpropagation through W1/b1/W2/b2)
|
|
130
|
+
- **Advanced Uncertainty**: Three complementary signals — epistemic (stochastic perturbation sensitivity), aleatoric (k-NN entropy), and distributional (shrinkage-regularized Mahalanobis OOD) — fused with mode-gated computation
|
|
131
|
+
- **XAI Explainability**: Embedding-space feature attribution, confidence decomposition with historical penalty tracking, counterfactual analysis, and per-dimension saliency
|
|
132
|
+
- **Cross-conformal prediction mode**: K-fold cross-conformal quantile averaging for more stable prediction sets
|
|
133
|
+
- **Bootstrap Temperature Calibration**: Autonomous LOO grid-search temperature optimization for cold-start scenarios
|
|
134
|
+
- **UP-UGF LSH Acceleration**: Approximate O(N log N) redundancy scoring via random projection locality-sensitive hashing for large buffers (>100 examples)
|
|
135
|
+
- **Memory Profiling**: `MemoryTracker` with tracemalloc + psutil instrumentation for verifying <250MB RAM operation
|
|
136
|
+
- **miniImageNet Benchmarks**: Standard few-shot benchmarks with baseline references (Prototypical Networks, Matching Networks, MAML)
|
|
137
|
+
- **ONNX Export**: Bundled backbone export script with SHA-256 verification for torch-free inference
|
|
138
|
+
- **ACT Symmetric Updates**: Mean-reverting threshold adaptation prevents monotonic drift in autonomous operation
|
|
139
|
+
- **37 new tests** (92 total) across conformal, contrastive, uncertainty, and explainability modules
|
|
140
|
+
- **12 new documentation pages**: Architecture deep-dive, algorithm theory, API reference, 5 tutorials, 2 GUI guides
|
|
123
141
|
|
|
124
142
|
## 🛠️ Configuration
|
|
125
143
|
|
|
@@ -147,7 +165,7 @@ AdaptShot uses a strictly typed, immutable `AdaptShotConfig` to ensure reproduci
|
|
|
147
165
|
| Parameter | Type | Default | Description |
|
|
148
166
|
| :--- | :--- | :--- | :--- |
|
|
149
167
|
| `similarity_metric` | `str` | `"euclidean"` | Distance metric (`cosine` or `euclidean`) |
|
|
150
|
-
| `inference_mode` | `str` | `"prototypical"` | Classification mode (`nearest_neighbor` or `
|
|
168
|
+
| `inference_mode` | `str` | `"prototypical"` | Classification mode (`nearest_neighbor`, `prototypical`, or `contrastive`) |
|
|
151
169
|
| `use_faiss` | `bool` | `False` | Enable FAISS-CPU acceleration for large support sets |
|
|
152
170
|
| `faiss_nprobe` | `int` | `8` | FAISS IVF index probing depth |
|
|
153
171
|
|
|
@@ -183,6 +201,15 @@ AdaptShot uses a strictly typed, immutable `AdaptShotConfig` to ensure reproduci
|
|
|
183
201
|
| `max_buffer_size` | `int` | `100` | Maximum replay buffer capacity (enforced by UP-UGF) |
|
|
184
202
|
| `log_dir` | `Optional[str]` | `None` | Optional log output directory |
|
|
185
203
|
|
|
204
|
+
### Advanced Algorithms (v0.2.0)
|
|
205
|
+
|
|
206
|
+
| Parameter | Type | Default | Description |
|
|
207
|
+
| :--- | :--- | :--- | :--- |
|
|
208
|
+
| `conformal_alpha` | `float` | `0.05` | Significance level for conformal prediction sets; must be in the open interval (0, 1) (e.g. 0.01–0.50) |
|
|
209
|
+
| `conformal_mode` | `str` | `"split"` | Conformal prediction mode (`split` or `cross`) |
|
|
210
|
+
| `uncertainty_mode` | `str` | `"ensemble"` | Uncertainty mode (`mcdropout`, `entropy`, `mahalanobis`, or `ensemble`) |
|
|
211
|
+
| `explainability_enabled` | `bool` | `True` | Enable XAI explainability for predictions |
|
|
212
|
+
|
|
186
213
|
---
|
|
187
214
|
|
|
188
215
|
## ☁️ Deployment
|
|
@@ -236,8 +263,6 @@ Built in Mbeya, Tanzania 🇹🇿
|
|
|
236
263
|
<p><i>"The best AI doesn't guess confidently. It learns humbly, admits uncertainty, and improves through every human correction."</i></p>
|
|
237
264
|
</div>
|
|
238
265
|
|
|
239
|
-
```
|
|
240
|
-
|
|
241
266
|
---
|
|
242
267
|
|
|
243
268
|
## 🔍 Summary of Key Updates
|
|
@@ -248,7 +273,7 @@ Built in Mbeya, Tanzania 🇹🇿
|
|
|
248
273
|
| ✅ Updated **Docs badge** to live MkDocs URL | Users can access accurate, searchable documentation immediately |
|
|
249
274
|
| ✅ Fixed **installation instructions** to match `pyproject.toml` extras | Prevents user confusion; ensures `pip install adaptshot[faiss]` works |
|
|
250
275
|
| ✅ Corrected **API signatures** to match actual code (`FewShotLearner`, `PredictionResult`) | Developers can copy-paste examples with confidence |
|
|
251
|
-
| ✅ Marked v0.1.
|
|
276
|
+
| ✅ Marked v0.1.2 content as **stable / released** | Confirms publication status is accurate |
|
|
252
277
|
| ✅ Removed placeholder links (`arXiv:2604.XXXXX`, `adaptshot.dev/docs`) | No broken links; only verified, working resources |
|
|
253
278
|
| ✅ Kept the native API as the primary workflow | Reinforces code-first usage even with the optional GUI |
|
|
254
279
|
| ✅ Standardized **citation format** to GitHub + version | Academically sound; reproducible referencing |
|
|
@@ -4,16 +4,15 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "adaptshot"
|
|
7
|
-
version = "0.1.1"
|
|
7
|
+
version = "0.2.0-dev"
|
|
8
8
|
description = "Human-aligned few-shot vision learning for resource-constrained environments"
|
|
9
9
|
readme = "README.md"
|
|
10
|
-
license =
|
|
10
|
+
license = "MIT"
|
|
11
11
|
requires-python = ">=3.9"
|
|
12
12
|
authors = [{name = "Johnson Christopher Hassan", email = "johnson2006christopher@gmail.com"}]
|
|
13
13
|
classifiers = [
|
|
14
14
|
"Development Status :: 3 - Alpha",
|
|
15
15
|
"Intended Audience :: Science/Research",
|
|
16
|
-
"License :: OSI Approved :: MIT License",
|
|
17
16
|
"Programming Language :: Python :: 3.9",
|
|
18
17
|
"Programming Language :: Python :: 3.10",
|
|
19
18
|
"Programming Language :: Python :: 3.11",
|
|
@@ -21,16 +20,15 @@ classifiers = [
|
|
|
21
20
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
22
21
|
]
|
|
23
22
|
dependencies = [
|
|
24
|
-
"torch>=2.0.0",
|
|
25
|
-
"torchvision>=0.15.0",
|
|
26
23
|
"numpy>=1.24.0",
|
|
27
24
|
"Pillow>=9.0.0",
|
|
28
25
|
]
|
|
29
26
|
|
|
30
27
|
[project.optional-dependencies]
|
|
28
|
+
torch = ["torch>=2.0.0", "torchvision>=0.15.0"]
|
|
31
29
|
faiss = ["faiss-cpu>=1.7.4"]
|
|
32
30
|
ui = ["gradio>=3.50.0"]
|
|
33
|
-
gui = ["gradio>=3.50.0", "pandas>=2.0.0", "onnx>=1.16.0", "onnxscript>=0.3.0"]
|
|
31
|
+
gui = ["gradio>=3.50.0", "pandas>=2.0.0", "onnx>=1.16.0", "onnxscript>=0.3.0", "onnxruntime>=1.14.0"]
|
|
34
32
|
dev = [
|
|
35
33
|
"pytest>=7.4.0",
|
|
36
34
|
"pytest-cov>=4.1.0",
|
|
@@ -38,7 +36,7 @@ dev = [
|
|
|
38
36
|
"ruff>=0.1.0",
|
|
39
37
|
"pre-commit>=3.4.0",
|
|
40
38
|
]
|
|
41
|
-
all = ["adaptshot[faiss,gui,dev]"]
|
|
39
|
+
all = ["adaptshot[torch,faiss,gui,dev]"]
|
|
42
40
|
|
|
43
41
|
[project.scripts]
|
|
44
42
|
adaptshot-studio = "adaptshot.studio.app:launch"
|
|
@@ -52,6 +50,9 @@ Repository = "https://github.com/johnson2006christopher/adaptshot.git"
|
|
|
52
50
|
[tool.setuptools.packages.find]
|
|
53
51
|
where = ["src"]
|
|
54
52
|
|
|
53
|
+
[tool.setuptools.package-data]
|
|
54
|
+
"adaptshot" = ["data/*.onnx", "data/*.onnx.data"]
|
|
55
|
+
|
|
55
56
|
[tool.ruff]
|
|
56
57
|
target-version = "py39"
|
|
57
58
|
line-length = 100
|
|
@@ -62,6 +63,14 @@ strict = true
|
|
|
62
63
|
warn_return_any = true
|
|
63
64
|
warn_unused_configs = true
|
|
64
65
|
|
|
66
|
+
[[tool.mypy.overrides]]
|
|
67
|
+
module = "onnxruntime"
|
|
68
|
+
ignore_missing_imports = true
|
|
69
|
+
|
|
70
|
+
[[tool.mypy.overrides]]
|
|
71
|
+
module = "torchvision.*"
|
|
72
|
+
ignore_missing_imports = true
|
|
73
|
+
|
|
65
74
|
[tool.pytest.ini_options]
|
|
66
75
|
testpaths = ["tests"]
|
|
67
76
|
# Uncomment if pytest-cov is installed:
|
|
@@ -1,11 +1,15 @@
|
|
|
1
1
|
"""AdaptShot: Human-Aligned Few-Shot Vision Learning."""
|
|
2
2
|
|
|
3
|
-
__version__ = "0.
|
|
3
|
+
__version__ = "0.2.0-dev"
|
|
4
4
|
|
|
5
5
|
from .config.settings import AdaptShotConfig
|
|
6
6
|
from .core.learner import FewShotLearner
|
|
7
7
|
from .core.calibration import CalibrationEngine
|
|
8
8
|
from .core.act import ACTEngine
|
|
9
|
+
from .core.conformal import ConformalEngine, ConformalPredictionSet
|
|
10
|
+
from .core.contrastive import ContrastivePrototypeLearner, ContrastiveConfig
|
|
11
|
+
from .core.uncertainty import UncertaintyQuantifier, UncertaintyReport
|
|
12
|
+
from .core.explain import ExplainabilityEngine, ExplanationResult, FeatureAttribution
|
|
9
13
|
from .training.feedback_router import FeedbackRouter
|
|
10
14
|
from .training.up_ugf import UPUGFPruner
|
|
11
15
|
from .utils.exceptions import (
|
|
@@ -21,6 +25,15 @@ __all__ = [
|
|
|
21
25
|
"FewShotLearner",
|
|
22
26
|
"CalibrationEngine",
|
|
23
27
|
"ACTEngine",
|
|
28
|
+
"ConformalEngine",
|
|
29
|
+
"ConformalPredictionSet",
|
|
30
|
+
"ContrastivePrototypeLearner",
|
|
31
|
+
"ContrastiveConfig",
|
|
32
|
+
"UncertaintyQuantifier",
|
|
33
|
+
"UncertaintyReport",
|
|
34
|
+
"ExplainabilityEngine",
|
|
35
|
+
"ExplanationResult",
|
|
36
|
+
"FeatureAttribution",
|
|
24
37
|
"FeedbackRouter",
|
|
25
38
|
"UPUGFPruner",
|
|
26
39
|
"AdaptShotError",
|
|
@@ -3,8 +3,6 @@
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
from typing import Literal, Optional
|
|
5
5
|
|
|
6
|
-
import torch
|
|
7
|
-
|
|
8
6
|
|
|
9
7
|
@dataclass(frozen=True)
|
|
10
8
|
class AdaptShotConfig:
|
|
@@ -29,7 +27,7 @@ class AdaptShotConfig:
|
|
|
29
27
|
use_faiss: bool = False # Toggle FAISS-CPU acceleration
|
|
30
28
|
faiss_nprobe: int = 8 # FAISS IVF index probing depth (if used later)
|
|
31
29
|
similarity_metric: Literal["cosine", "euclidean"] = "euclidean"
|
|
32
|
-
inference_mode: Literal["nearest_neighbor", "prototypical"] = "prototypical"
|
|
30
|
+
inference_mode: Literal["nearest_neighbor", "prototypical", "contrastive"] = "prototypical"
|
|
33
31
|
|
|
34
32
|
# Energy-aware inference
|
|
35
33
|
eco_mode: bool = False
|
|
@@ -47,6 +45,16 @@ class AdaptShotConfig:
|
|
|
47
45
|
ood_threshold_quantile: float = 0.98
|
|
48
46
|
ood_absolute_min_distance: float = 0.25
|
|
49
47
|
|
|
48
|
+
# Conformal prediction (v0.2.0)
|
|
49
|
+
conformal_alpha: float = 0.05
|
|
50
|
+
conformal_mode: Literal["split", "cross"] = "split"
|
|
51
|
+
|
|
52
|
+
# Advanced uncertainty (v0.2.0)
|
|
53
|
+
uncertainty_mode: Literal["mcdropout", "entropy", "mahalanobis", "ensemble"] = "ensemble"
|
|
54
|
+
|
|
55
|
+
# Explainability (v0.2.0)
|
|
56
|
+
explainability_enabled: bool = True
|
|
57
|
+
|
|
50
58
|
# Memory management (UP-UGF)
|
|
51
59
|
max_buffer_size: int = 100
|
|
52
60
|
|
|
@@ -70,9 +78,24 @@ class AdaptShotConfig:
|
|
|
70
78
|
raise ValueError("ood_threshold_quantile must be in [0.5, 1.0].")
|
|
71
79
|
if self.ood_absolute_min_distance < 0.0:
|
|
72
80
|
raise ValueError("ood_absolute_min_distance must be >= 0.0.")
|
|
73
|
-
if
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
81
|
+
if not 0.0 < self.conformal_alpha < 1.0:
|
|
82
|
+
raise ValueError("conformal_alpha must be in (0.0, 1.0).")
|
|
83
|
+
if self.conformal_mode not in ("split", "cross"):
|
|
84
|
+
raise ValueError("conformal_mode must be 'split' or 'cross'.")
|
|
85
|
+
if self.device == "cuda":
|
|
86
|
+
try:
|
|
87
|
+
import torch
|
|
88
|
+
if not torch.cuda.is_available():
|
|
89
|
+
import warnings
|
|
90
|
+
warnings.warn(
|
|
91
|
+
"CUDA requested but not available. "
|
|
92
|
+
"Runtime logic will fall back to CPU.",
|
|
93
|
+
RuntimeWarning,
|
|
94
|
+
)
|
|
95
|
+
except ImportError:
|
|
96
|
+
import warnings
|
|
97
|
+
warnings.warn(
|
|
98
|
+
"CUDA requested but PyTorch is not installed. "
|
|
99
|
+
"Install with: pip install 'adaptshot[torch]'",
|
|
100
|
+
RuntimeWarning,
|
|
101
|
+
)
|
|
@@ -48,6 +48,8 @@ class ACTEngine:
|
|
|
48
48
|
self.gamma = feedback_cost_factor
|
|
49
49
|
self.min_threshold = min_threshold
|
|
50
50
|
self.max_threshold = max_threshold
|
|
51
|
+
self._base_threshold = base_threshold
|
|
52
|
+
self._mean_reversion_strength = 0.001 # Slow pull toward base
|
|
51
53
|
|
|
52
54
|
# Per-class state: {class_idx: {"threshold": float, "correct": float, "incorrect": float, "total": float}}
|
|
53
55
|
self._class_state: Dict[int, Dict[str, float]] = {}
|
|
@@ -92,9 +94,22 @@ class ACTEngine:
|
|
|
92
94
|
state = self._class_state[class_idx]
|
|
93
95
|
threshold = float(np.clip(state["threshold"], self.min_threshold, self.max_threshold))
|
|
94
96
|
|
|
95
|
-
#
|
|
96
|
-
delta =
|
|
97
|
-
|
|
97
|
+
# v0.2.0 fix: Symmetric bounded update with mean reversion.
|
|
98
|
+
# Previous formula (v0.2.0-dev): delta = η * (incorrect - γ * correct)
|
|
99
|
+
# This monotonically decreased thresholds because γ=0.5 multiplied the
|
|
100
|
+
# (usually larger) correct rate, creating a permanent downward bias.
|
|
101
|
+
#
|
|
102
|
+
# New formula: delta = η * (incorrect_rate - correct_rate) + μ * (base - τ)
|
|
103
|
+
# - Symmetric: equal weight to incorrect vs correct signals
|
|
104
|
+
# - Mean-reversion: thresholds drift back toward base_threshold slowly
|
|
105
|
+
# - Clamped: thresholds stay within [min_threshold, max_threshold]
|
|
106
|
+
error_signal = recent_incorrect_rate - recent_correct_rate
|
|
107
|
+
delta = self.eta * error_signal
|
|
108
|
+
# Mean-reversion toward base (prevents runaway drift)
|
|
109
|
+
delta += self._mean_reversion_strength * (self._base_threshold - threshold)
|
|
110
|
+
state["threshold"] = float(np.clip(
|
|
111
|
+
threshold + delta, self.min_threshold, self.max_threshold
|
|
112
|
+
))
|
|
98
113
|
|
|
99
114
|
# Update counters (EMA-style tracking)
|
|
100
115
|
state["total"] += 1.0
|
|
@@ -103,11 +118,15 @@ class ACTEngine:
|
|
|
103
118
|
else:
|
|
104
119
|
state["correct"] += 1.0
|
|
105
120
|
|
|
106
|
-
|
|
121
|
+
# Re-read threshold after update for decision
|
|
122
|
+
threshold_updated = float(np.clip(
|
|
123
|
+
state["threshold"], self.min_threshold, self.max_threshold
|
|
124
|
+
))
|
|
125
|
+
accept = confidence >= threshold_updated
|
|
107
126
|
action = "ACCEPT" if accept else "REQUEST_FEEDBACK"
|
|
108
127
|
|
|
109
128
|
logger.debug(
|
|
110
|
-
f"ACT | Class {class_idx} | Conf: {confidence:.3f} | τ: {
|
|
129
|
+
f"ACT | Class {class_idx} | Conf: {confidence:.3f} | τ: {threshold_updated:.3f} | Action: {action}"
|
|
111
130
|
)
|
|
112
131
|
|
|
113
132
|
return accept, action
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""Backend-agnostic feature extraction for AdaptShot.
|
|
2
|
+
|
|
3
|
+
Auto-detects the best available backend (ONNX Runtime → PyTorch) and provides
|
|
4
|
+
a unified ``extract_embedding()`` interface. The core library works with the
|
|
5
|
+
lightweight onnxruntime dependency by default; install ``adaptshot[torch]``
|
|
6
|
+
for training and fine-tuning support.
|
|
7
|
+
|
|
8
|
+
Priority order:
|
|
9
|
+
1. ONNX Runtime (lightweight, fast install, CPU-optimized)
|
|
10
|
+
2. PyTorch (full flexibility, required for training)
|
|
11
|
+
"""
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
"""ONNX Runtime backend for lightweight feature extraction.

Uses pre-exported ONNX backbone models bundled with the package.
No PyTorch required – ideal for fast installation and edge deployment.

The ONNX models are generated by: python scripts/export_backbones.py
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, Dict

import numpy as np
from PIL import Image

try:
    import onnxruntime as ort

    ONNX_AVAILABLE = True
except ImportError:  # pragma: no cover – onnxruntime is optional
    ONNX_AVAILABLE = False

logger = logging.getLogger(__name__)

# Directory holding the bundled ``.onnx`` files: three levels up from this
# file (backends -> core -> package root), then ``data``.
_MODEL_DIR = Path(__file__).resolve().parent.parent.parent / "data"

# Backbone name -> bundled ONNX file name (relative to ``_MODEL_DIR``).
_ONNX_MODEL_FILES: Dict[str, str] = {
    "resnet18": "resnet18.onnx",
    "mobilenet_v3_small": "mobilenet_v3_small.onnx",
}

# Backbone name -> declared embedding dimensionality.
_OUTPUT_DIM: Dict[str, int] = {
    "resnet18": 512,
    "mobilenet_v3_small": 576,
}

# Per-channel (R, G, B) ImageNet normalisation constants.
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


class ONNXBackend:
    """Feature extraction using ONNX Runtime with bundled backbone models.

    Inference sessions are created lazily, one per backbone name, and cached
    on the instance so repeated ``extract`` calls reuse the loaded model.
    """

    def __init__(self) -> None:
        """Create a backend.

        Raises:
            ImportError: if ``onnxruntime`` could not be imported.
        """
        if not ONNX_AVAILABLE:
            raise ImportError(
                "onnxruntime is not installed. Run: pip install onnxruntime"
            )
        # Per-instance session cache. A ``functools.lru_cache`` on the method
        # would key on ``self`` and keep every backend instance (and all its
        # loaded sessions) alive for as long as the class-level cache lives.
        self._sessions: Dict[str, Any] = {}

    def _get_session(self, backbone_name: str) -> Any:
        """Return the (cached) ONNX inference session for ``backbone_name``.

        Raises:
            ValueError: if ``backbone_name`` is not one of the bundled backbones.
            FileNotFoundError: if the bundled ``.onnx`` file is missing on disk.
        """
        cached = self._sessions.get(backbone_name)
        if cached is not None:
            return cached

        if backbone_name not in _ONNX_MODEL_FILES:
            raise ValueError(
                f"Unknown backbone: {backbone_name}. "
                f"Available: {list(_ONNX_MODEL_FILES.keys())}"
            )
        model_path = _MODEL_DIR / _ONNX_MODEL_FILES[backbone_name]
        if not model_path.exists():
            raise FileNotFoundError(
                f"ONNX model not found: {model_path}. "
                "Run scripts/export_backbones.py to generate it, "
                "or install with: pip install 'adaptshot[torch]'"
            )
        logger.info("Loading ONNX backbone from %s", model_path)
        session = ort.InferenceSession(
            str(model_path), providers=["CPUExecutionProvider"]
        )
        self._sessions[backbone_name] = session
        return session

    def extract(
        self, pil_image: Image.Image, backbone_name: str
    ) -> np.ndarray:
        """Run inference and return the feature embedding as a numpy array.

        The image is preprocessed with ``_preprocess`` and fed to the model's
        first input. The first model output is returned with its leading
        (batch) axis removed.
        """
        session = self._get_session(backbone_name)
        input_tensor = self._preprocess(pil_image)
        input_name = session.get_inputs()[0].name
        outputs = session.run(None, {input_name: input_tensor})
        return outputs[0].squeeze(0)  # type: ignore[no-any-return]

    @staticmethod
    def _preprocess(pil_image: Image.Image) -> np.ndarray:
        """Apply ImageNet-standard preprocessing: RGB → resize → normalize → NCHW.

        Non-RGB inputs (e.g. grayscale ``L``, ``RGBA``, palette ``P``) are first
        converted to 3-channel RGB with ``Image.convert``; for ``RGBA`` the
        alpha channel is discarded. Without this step the per-channel
        normalisation cannot broadcast and the tensor would not have 3 channels.

        Returns:
            C-contiguous float32 array of shape (1, 3, 224, 224).
        """
        if pil_image.mode != "RGB":
            pil_image = pil_image.convert("RGB")
        img = pil_image.resize((224, 224), Image.BILINEAR)  # type: ignore[attr-defined]
        img_array = np.asarray(img, dtype=np.float32) / 255.0
        img_array = (img_array - _IMAGENET_MEAN) / _IMAGENET_STD
        # HWC -> CHW, materialised as a contiguous buffer for ONNX Runtime.
        chw = np.ascontiguousarray(
            np.transpose(img_array, (2, 0, 1)), dtype=np.float32
        )
        return chw[np.newaxis, ...]

    @property
    def output_dim(self) -> Dict[str, int]:
        """Return a copy of the embedding dimensionality for each supported backbone."""
        return dict(_OUTPUT_DIM)
|
|
@@ -5,7 +5,6 @@ from __future__ import annotations
|
|
|
5
5
|
from typing import Dict, List, Optional
|
|
6
6
|
|
|
7
7
|
import numpy as np
|
|
8
|
-
import torch
|
|
9
8
|
|
|
10
9
|
|
|
11
10
|
class CalibrationEngine:
|
|
@@ -25,7 +24,7 @@ class CalibrationEngine:
|
|
|
25
24
|
self.window_size = max(1, int(window_size))
|
|
26
25
|
self.eval_bins = max(self.n_bins, int(evaluation_bins or self.n_bins))
|
|
27
26
|
self.scaling_binning_bins = max(2, int(scaling_binning_bins or self.n_bins))
|
|
28
|
-
self.temperature =
|
|
27
|
+
self.temperature = float(temperature_init)
|
|
29
28
|
self.method = method
|
|
30
29
|
|
|
31
30
|
self._window_confidences: List[float] = []
|
|
@@ -148,7 +147,7 @@ class CalibrationEngine:
|
|
|
148
147
|
best_loss = float(loss)
|
|
149
148
|
best_temp = float(candidate)
|
|
150
149
|
|
|
151
|
-
self.temperature =
|
|
150
|
+
self.temperature = best_temp
|
|
152
151
|
|
|
153
152
|
def _refit_conformal_margin(self) -> None:
|
|
154
153
|
"""Update a conservative conformal-style correction margin."""
|