shortkit-ml 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- shortkit_ml-0.1.0/LICENSE +21 -0
- shortkit_ml-0.1.0/PKG-INFO +524 -0
- shortkit_ml-0.1.0/README.md +425 -0
- shortkit_ml-0.1.0/pyproject.toml +215 -0
- shortkit_ml-0.1.0/setup.cfg +4 -0
- shortkit_ml-0.1.0/shortcut_detect/__init__.py +93 -0
- shortkit_ml-0.1.0/shortcut_detect/base_builder.py +122 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/__init__.py +87 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/baseline_comparison.py +404 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/chexpert_extraction.py +154 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/convergence_viz.py +384 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/figures.py +509 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/fp_analysis.py +169 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/measurement.py +524 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/method_utils.py +197 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/paper_run.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/paper_runner.py +1071 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/run.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/runner.py +797 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/sensitivity.py +481 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/synthetic.py +111 -0
- shortkit_ml-0.1.0/shortcut_detect/benchmark/synthetic_generator.py +395 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/__init__.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/causal_effect/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/causal_effect/builder.py +89 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/causal_effect/registry.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/causal_effect/src/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/causal_effect/src/detector.py +267 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/causal_effect_detector.py +17 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/generative_cvae/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/generative_cvae/builder.py +78 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/generative_cvae/registry.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/generative_cvae/src/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/generative_cvae/src/detector.py +426 -0
- shortkit_ml-0.1.0/shortcut_detect/causal/plugin.py +4 -0
- shortkit_ml-0.1.0/shortcut_detect/clustering/__init__.py +8 -0
- shortkit_ml-0.1.0/shortcut_detect/clustering/builder.py +57 -0
- shortkit_ml-0.1.0/shortcut_detect/clustering/hbac_detector.py +528 -0
- shortkit_ml-0.1.0/shortcut_detect/clustering/plugin.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/comparison/__init__.py +7 -0
- shortkit_ml-0.1.0/shortcut_detect/comparison/runner.py +273 -0
- shortkit_ml-0.1.0/shortcut_detect/conditions/__init__.py +22 -0
- shortkit_ml-0.1.0/shortcut_detect/conditions/base.py +26 -0
- shortkit_ml-0.1.0/shortcut_detect/conditions/indicator_count.py +27 -0
- shortkit_ml-0.1.0/shortcut_detect/conditions/majority_vote.py +47 -0
- shortkit_ml-0.1.0/shortcut_detect/conditions/meta_classifier.py +211 -0
- shortkit_ml-0.1.0/shortcut_detect/conditions/meta_model.joblib +0 -0
- shortkit_ml-0.1.0/shortcut_detect/conditions/meta_model.meta.json +43 -0
- shortkit_ml-0.1.0/shortcut_detect/conditions/multi_attribute.py +108 -0
- shortkit_ml-0.1.0/shortcut_detect/conditions/registry.py +38 -0
- shortkit_ml-0.1.0/shortcut_detect/conditions/weighted_risk.py +131 -0
- shortkit_ml-0.1.0/shortcut_detect/datasets.py +220 -0
- shortkit_ml-0.1.0/shortcut_detect/detector_base.py +313 -0
- shortkit_ml-0.1.0/shortcut_detect/detector_template.py +368 -0
- shortkit_ml-0.1.0/shortcut_detect/discovery.py +51 -0
- shortkit_ml-0.1.0/shortcut_detect/embedding_sources.py +174 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/__init__.py +14 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/demographic_parity/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/demographic_parity/builder.py +73 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/demographic_parity/registry.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/demographic_parity/src/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/demographic_parity/src/detector.py +159 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/equalized_odds/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/equalized_odds/builder.py +74 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/equalized_odds/registry.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/equalized_odds/src/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/equalized_odds/src/detector.py +196 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/intersectional/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/intersectional/builder.py +96 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/intersectional/registry.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/intersectional/src/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/intersectional/src/detector.py +324 -0
- shortkit_ml-0.1.0/shortcut_detect/fairness/plugin.py +5 -0
- shortkit_ml-0.1.0/shortcut_detect/frequency/__init__.py +5 -0
- shortkit_ml-0.1.0/shortcut_detect/frequency/adcs.py +28 -0
- shortkit_ml-0.1.0/shortcut_detect/frequency/builder.py +68 -0
- shortkit_ml-0.1.0/shortcut_detect/frequency/detector.py +221 -0
- shortkit_ml-0.1.0/shortcut_detect/frequency/frequency_detector.py +9 -0
- shortkit_ml-0.1.0/shortcut_detect/frequency/plugin.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/frequency/sensitivity.py +37 -0
- shortkit_ml-0.1.0/shortcut_detect/gce/__init__.py +5 -0
- shortkit_ml-0.1.0/shortcut_detect/gce/builder.py +64 -0
- shortkit_ml-0.1.0/shortcut_detect/gce/gce_detector.py +301 -0
- shortkit_ml-0.1.0/shortcut_detect/gce/plugin.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/__init__.py +15 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/bias_direction_pca/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/bias_direction_pca/builder.py +69 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/bias_direction_pca/registry.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/bias_direction_pca/src/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/bias_direction_pca/src/detector.py +182 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/geometric/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/geometric/builder.py +73 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/geometric/registry.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/geometric/src/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/geometric/src/detector.py +281 -0
- shortkit_ml-0.1.0/shortcut_detect/geometric/plugin.py +4 -0
- shortkit_ml-0.1.0/shortcut_detect/gradcam.py +367 -0
- shortkit_ml-0.1.0/shortcut_detect/groupdro/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/groupdro/builder.py +137 -0
- shortkit_ml-0.1.0/shortcut_detect/groupdro/groupdro.py +1136 -0
- shortkit_ml-0.1.0/shortcut_detect/groupdro/plugin.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/mcp_server.py +460 -0
- shortkit_ml-0.1.0/shortcut_detect/metrics.py +35 -0
- shortkit_ml-0.1.0/shortcut_detect/mitigation/__init__.py +17 -0
- shortkit_ml-0.1.0/shortcut_detect/mitigation/adversarial_debiasing.py +269 -0
- shortkit_ml-0.1.0/shortcut_detect/mitigation/background_randomizer.py +102 -0
- shortkit_ml-0.1.0/shortcut_detect/mitigation/contrastive_debiasing.py +325 -0
- shortkit_ml-0.1.0/shortcut_detect/mitigation/explanation_regularization.py +194 -0
- shortkit_ml-0.1.0/shortcut_detect/mitigation/last_layer_retraining.py +210 -0
- shortkit_ml-0.1.0/shortcut_detect/mitigation/shortcut_masking.py +202 -0
- shortkit_ml-0.1.0/shortcut_detect/model_registry.py +131 -0
- shortkit_ml-0.1.0/shortcut_detect/probes/__init__.py +30 -0
- shortkit_ml-0.1.0/shortcut_detect/probes/builder.py +155 -0
- shortkit_ml-0.1.0/shortcut_detect/probes/pipeline.py +78 -0
- shortkit_ml-0.1.0/shortcut_detect/probes/plugin.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/probes/probe_factory.py +83 -0
- shortkit_ml-0.1.0/shortcut_detect/probes/sklearn_probe.py +392 -0
- shortkit_ml-0.1.0/shortcut_detect/probes/torch_probe.py +686 -0
- shortkit_ml-0.1.0/shortcut_detect/reporting/__init__.py +7 -0
- shortkit_ml-0.1.0/shortcut_detect/reporting/comparison_report.py +194 -0
- shortkit_ml-0.1.0/shortcut_detect/reporting/csv_export.py +786 -0
- shortkit_ml-0.1.0/shortcut_detect/reporting/report_builder.py +1901 -0
- shortkit_ml-0.1.0/shortcut_detect/reporting/reporters.py +187 -0
- shortkit_ml-0.1.0/shortcut_detect/reporting/risk_format.py +372 -0
- shortkit_ml-0.1.0/shortcut_detect/reporting/visualizations.py +525 -0
- shortkit_ml-0.1.0/shortcut_detect/ssa/__init__.py +5 -0
- shortkit_ml-0.1.0/shortcut_detect/ssa/builder.py +229 -0
- shortkit_ml-0.1.0/shortcut_detect/ssa/plugin.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/ssa/ssa.py +608 -0
- shortkit_ml-0.1.0/shortcut_detect/statistical/__init__.py +15 -0
- shortkit_ml-0.1.0/shortcut_detect/statistical/builder.py +75 -0
- shortkit_ml-0.1.0/shortcut_detect/statistical/group_diff_test.py +295 -0
- shortkit_ml-0.1.0/shortcut_detect/statistical/plugin.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/training/__init__.py +22 -0
- shortkit_ml-0.1.0/shortcut_detect/training/builder.py +107 -0
- shortkit_ml-0.1.0/shortcut_detect/training/data_adapters.py +142 -0
- shortkit_ml-0.1.0/shortcut_detect/training/early_epoch_clustering.py +164 -0
- shortkit_ml-0.1.0/shortcut_detect/training/loader_hooks.py +96 -0
- shortkit_ml-0.1.0/shortcut_detect/training/plugin.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/unified.py +596 -0
- shortkit_ml-0.1.0/shortcut_detect/utils.py +112 -0
- shortkit_ml-0.1.0/shortcut_detect/vae/__init__.py +14 -0
- shortkit_ml-0.1.0/shortcut_detect/vae/builder.py +129 -0
- shortkit_ml-0.1.0/shortcut_detect/vae/latent_analyzer.py +96 -0
- shortkit_ml-0.1.0/shortcut_detect/vae/plugin.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/vae/vae_arch.py +177 -0
- shortkit_ml-0.1.0/shortcut_detect/vae/vae_detector.py +387 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/__init__.py +8 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/cav/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/cav/builder.py +85 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/cav/registry.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/cav/src/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/cav/src/detector.py +353 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/gradcam_mask_overlap/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/gradcam_mask_overlap/builder.py +80 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/gradcam_mask_overlap/registry.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/gradcam_mask_overlap/src/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/gradcam_mask_overlap/src/detector.py +292 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/plugin.py +5 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/sis/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/sis/builder.py +75 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/sis/registry.py +6 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/sis/src/__init__.py +1 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/sis/src/detector.py +285 -0
- shortkit_ml-0.1.0/shortcut_detect/xai/spray_detector.py +547 -0
- shortkit_ml-0.1.0/shortkit_ml.egg-info/PKG-INFO +524 -0
- shortkit_ml-0.1.0/shortkit_ml.egg-info/SOURCES.txt +226 -0
- shortkit_ml-0.1.0/shortkit_ml.egg-info/dependency_links.txt +1 -0
- shortkit_ml-0.1.0/shortkit_ml.egg-info/entry_points.txt +2 -0
- shortkit_ml-0.1.0/shortkit_ml.egg-info/requires.txt +86 -0
- shortkit_ml-0.1.0/shortkit_ml.egg-info/top_level.txt +1 -0
- shortkit_ml-0.1.0/tests/test_adversarial_debiasing.py +176 -0
- shortkit_ml-0.1.0/tests/test_aggregation_formula.py +241 -0
- shortkit_ml-0.1.0/tests/test_background_randomizer.py +79 -0
- shortkit_ml-0.1.0/tests/test_baseline_comparison.py +92 -0
- shortkit_ml-0.1.0/tests/test_benchmark.py +178 -0
- shortkit_ml-0.1.0/tests/test_benchmark_synthetic.py +102 -0
- shortkit_ml-0.1.0/tests/test_causal.py +296 -0
- shortkit_ml-0.1.0/tests/test_causal_effect.py +216 -0
- shortkit_ml-0.1.0/tests/test_causal_effect_loader_integration.py +58 -0
- shortkit_ml-0.1.0/tests/test_cav.py +108 -0
- shortkit_ml-0.1.0/tests/test_cav_loader_integration.py +36 -0
- shortkit_ml-0.1.0/tests/test_clustering.py +108 -0
- shortkit_ml-0.1.0/tests/test_conditions.py +627 -0
- shortkit_ml-0.1.0/tests/test_contrastive_debiasing.py +197 -0
- shortkit_ml-0.1.0/tests/test_convergence_viz.py +270 -0
- shortkit_ml-0.1.0/tests/test_dashboard.py +79 -0
- shortkit_ml-0.1.0/tests/test_detector_base.py +41 -0
- shortkit_ml-0.1.0/tests/test_detector_factory.py +64 -0
- shortkit_ml-0.1.0/tests/test_early_epoch_clustering.py +62 -0
- shortkit_ml-0.1.0/tests/test_edge_cases.py +117 -0
- shortkit_ml-0.1.0/tests/test_effect_size_calibration.py +70 -0
- shortkit_ml-0.1.0/tests/test_embedding_mode.py +85 -0
- shortkit_ml-0.1.0/tests/test_explanation_regularization.py +143 -0
- shortkit_ml-0.1.0/tests/test_fairness.py +455 -0
- shortkit_ml-0.1.0/tests/test_figures.py +178 -0
- shortkit_ml-0.1.0/tests/test_fp_analysis.py +71 -0
- shortkit_ml-0.1.0/tests/test_frequency.py +86 -0
- shortkit_ml-0.1.0/tests/test_gce.py +219 -0
- shortkit_ml-0.1.0/tests/test_geometric.py +96 -0
- shortkit_ml-0.1.0/tests/test_gradcam.py +228 -0
- shortkit_ml-0.1.0/tests/test_gradcam_mask_overlap.py +39 -0
- shortkit_ml-0.1.0/tests/test_groupdro.py +282 -0
- shortkit_ml-0.1.0/tests/test_last_layer_retraining.py +182 -0
- shortkit_ml-0.1.0/tests/test_mcp_server.py +336 -0
- shortkit_ml-0.1.0/tests/test_measurement_harness.py +289 -0
- shortkit_ml-0.1.0/tests/test_method_utils.py +171 -0
- shortkit_ml-0.1.0/tests/test_model_comparison.py +122 -0
- shortkit_ml-0.1.0/tests/test_model_registry.py +39 -0
- shortkit_ml-0.1.0/tests/test_multi_attribute.py +130 -0
- shortkit_ml-0.1.0/tests/test_paper_benchmark.py +77 -0
- shortkit_ml-0.1.0/tests/test_probes.py +492 -0
- shortkit_ml-0.1.0/tests/test_reporting.py +229 -0
- shortkit_ml-0.1.0/tests/test_reproducibility_smoke.py +325 -0
- shortkit_ml-0.1.0/tests/test_risk_format.py +74 -0
- shortkit_ml-0.1.0/tests/test_score_normalization.py +111 -0
- shortkit_ml-0.1.0/tests/test_sensitivity.py +206 -0
- shortkit_ml-0.1.0/tests/test_shortcut_masking.py +160 -0
- shortkit_ml-0.1.0/tests/test_sis_detector.py +125 -0
- shortkit_ml-0.1.0/tests/test_spray.py +59 -0
- shortkit_ml-0.1.0/tests/test_ssa.py +305 -0
- shortkit_ml-0.1.0/tests/test_statistical.py +225 -0
- shortkit_ml-0.1.0/tests/test_synthetic_generator.py +387 -0
- shortkit_ml-0.1.0/tests/test_threshold_audit.py +410 -0
- shortkit_ml-0.1.0/tests/test_unified.py +123 -0
- shortkit_ml-0.1.0/tests/test_vae.py +144 -0
- shortkit_ml-0.1.0/tests/test_vae_loader_integration.py +63 -0
- shortkit_ml-0.1.0/tests/test_validation.py +140 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 ShortKIT-ML Team
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,524 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: shortkit-ml
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: ShortKIT-ML: A toolkit for detecting shortcuts and biases in embedding spaces
|
|
5
|
+
Author: Sebastian Cajas, Aldo Marzullo, Sahil Kapadia, Qingpeng Kong, Filipe Santos, Alessandro Quarta, Leo Celi
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/criticaldata/ShortKit-ML
|
|
8
|
+
Project-URL: Documentation, https://criticaldata.github.io/ShortKit-ML/
|
|
9
|
+
Project-URL: Repository, https://github.com/criticaldata/ShortKit-ML
|
|
10
|
+
Project-URL: Issues, https://github.com/criticaldata/ShortKit-ML/issues
|
|
11
|
+
Keywords: machine-learning,bias-detection,embeddings,fairness,shortcuts
|
|
12
|
+
Classifier: Development Status :: 3 - Alpha
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Requires-Python: <3.13,>=3.10
|
|
20
|
+
Description-Content-Type: text/markdown
|
|
21
|
+
License-File: LICENSE
|
|
22
|
+
Requires-Dist: numpy<2.0.0,>=1.24.0
|
|
23
|
+
Requires-Dist: pandas<3.0.0,>=2.0.0
|
|
24
|
+
Requires-Dist: scikit-learn<2.0.0,>=1.3.0
|
|
25
|
+
Requires-Dist: scipy<2.0.0,>=1.11.0
|
|
26
|
+
Requires-Dist: statsmodels<1.0.0,>=0.14.0
|
|
27
|
+
Requires-Dist: matplotlib<4.0.0,>=3.7.0
|
|
28
|
+
Requires-Dist: seaborn<1.0.0,>=0.12.0
|
|
29
|
+
Requires-Dist: joblib<2.0.0,>=1.3.0
|
|
30
|
+
Requires-Dist: torch<2.3.0,>=2.2.0
|
|
31
|
+
Requires-Dist: torchvision<0.18.0,>=0.17.0
|
|
32
|
+
Requires-Dist: openpyxl<4.0.0,>=3.1.0
|
|
33
|
+
Requires-Dist: gradio>=5.0.0
|
|
34
|
+
Requires-Dist: plotly>=5.17.0
|
|
35
|
+
Requires-Dist: jinja2>=3.1.0
|
|
36
|
+
Requires-Dist: markdown>=3.5.0
|
|
37
|
+
Requires-Dist: weasyprint>=60.0
|
|
38
|
+
Provides-Extra: dev
|
|
39
|
+
Requires-Dist: pytest>=7.4.0; extra == "dev"
|
|
40
|
+
Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
|
|
41
|
+
Requires-Dist: pytest-xdist>=3.3.0; extra == "dev"
|
|
42
|
+
Requires-Dist: black>=23.0.0; extra == "dev"
|
|
43
|
+
Requires-Dist: ruff>=0.1.0; extra == "dev"
|
|
44
|
+
Requires-Dist: mypy>=1.5.0; extra == "dev"
|
|
45
|
+
Requires-Dist: pre-commit>=3.5.0; extra == "dev"
|
|
46
|
+
Provides-Extra: e2e
|
|
47
|
+
Requires-Dist: playwright>=1.40.0; extra == "e2e"
|
|
48
|
+
Provides-Extra: docs
|
|
49
|
+
Requires-Dist: mkdocs>=1.5.0; extra == "docs"
|
|
50
|
+
Requires-Dist: mkdocs-material>=9.5.0; extra == "docs"
|
|
51
|
+
Requires-Dist: mkdocstrings[python]>=0.24.0; extra == "docs"
|
|
52
|
+
Requires-Dist: pymdown-extensions>=10.0; extra == "docs"
|
|
53
|
+
Provides-Extra: jupyter
|
|
54
|
+
Requires-Dist: jupyter>=1.0.0; extra == "jupyter"
|
|
55
|
+
Requires-Dist: ipykernel>=6.25.0; extra == "jupyter"
|
|
56
|
+
Requires-Dist: nbconvert>=7.8.0; extra == "jupyter"
|
|
57
|
+
Requires-Dist: ipywidgets>=8.1.0; extra == "jupyter"
|
|
58
|
+
Provides-Extra: reporting
|
|
59
|
+
Requires-Dist: plotly>=5.17.0; extra == "reporting"
|
|
60
|
+
Requires-Dist: jinja2>=3.1.0; extra == "reporting"
|
|
61
|
+
Requires-Dist: markdown>=3.5.0; extra == "reporting"
|
|
62
|
+
Requires-Dist: weasyprint>=60.0; extra == "reporting"
|
|
63
|
+
Requires-Dist: openpyxl>=3.1.0; extra == "reporting"
|
|
64
|
+
Provides-Extra: dashboard
|
|
65
|
+
Requires-Dist: gradio>=5.0.0; extra == "dashboard"
|
|
66
|
+
Provides-Extra: hf
|
|
67
|
+
Requires-Dist: transformers>=4.39.0; extra == "hf"
|
|
68
|
+
Provides-Extra: vae
|
|
69
|
+
Requires-Dist: torch<2.3.0,>=2.2.0; extra == "vae"
|
|
70
|
+
Requires-Dist: torchvision<0.18.0,>=0.17.0; extra == "vae"
|
|
71
|
+
Provides-Extra: mcp
|
|
72
|
+
Requires-Dist: mcp>=1.0.0; extra == "mcp"
|
|
73
|
+
Provides-Extra: all
|
|
74
|
+
Requires-Dist: black>=23.0.0; extra == "all"
|
|
75
|
+
Requires-Dist: mcp>=1.0.0; extra == "all"
|
|
76
|
+
Requires-Dist: gradio>=5.0.0; extra == "all"
|
|
77
|
+
Requires-Dist: ipykernel>=6.25.0; extra == "all"
|
|
78
|
+
Requires-Dist: ipywidgets>=8.1.0; extra == "all"
|
|
79
|
+
Requires-Dist: jinja2>=3.1.0; extra == "all"
|
|
80
|
+
Requires-Dist: markdown>=3.5.0; extra == "all"
|
|
81
|
+
Requires-Dist: mypy>=1.5.0; extra == "all"
|
|
82
|
+
Requires-Dist: mkdocstrings[python]>=0.24.0; extra == "all"
|
|
83
|
+
Requires-Dist: mkdocs-material>=9.5.0; extra == "all"
|
|
84
|
+
Requires-Dist: mkdocs>=1.5.0; extra == "all"
|
|
85
|
+
Requires-Dist: nbconvert>=7.8.0; extra == "all"
|
|
86
|
+
Requires-Dist: openpyxl>=3.1.0; extra == "all"
|
|
87
|
+
Requires-Dist: plotly>=5.17.0; extra == "all"
|
|
88
|
+
Requires-Dist: pre-commit>=3.5.0; extra == "all"
|
|
89
|
+
Requires-Dist: pymdown-extensions>=10.0; extra == "all"
|
|
90
|
+
Requires-Dist: pytest-cov>=4.1.0; extra == "all"
|
|
91
|
+
Requires-Dist: pytest-xdist>=3.3.0; extra == "all"
|
|
92
|
+
Requires-Dist: pytest>=7.4.0; extra == "all"
|
|
93
|
+
Requires-Dist: ruff>=0.1.0; extra == "all"
|
|
94
|
+
Requires-Dist: torch<2.3.0,>=2.2.0; extra == "all"
|
|
95
|
+
Requires-Dist: torchvision<0.18.0,>=0.17.0; extra == "all"
|
|
96
|
+
Requires-Dist: transformers>=4.39.0; extra == "all"
|
|
97
|
+
Requires-Dist: weasyprint>=60.0; extra == "all"
|
|
98
|
+
Dynamic: license-file
|
|
99
|
+
|
|
100
|
+
# ShortKit-ML
|
|
101
|
+
|
|
102
|
+
> **ShortKit-ML** — Detect and mitigate shortcuts and biases in machine learning embedding spaces. 20+ detection methods and 6 mitigation methods behind a unified API. **Multi-attribute support** tests multiple sensitive attributes simultaneously. **Model Comparison mode** benchmarks multiple embedding models side-by-side.
|
|
103
|
+
|
|
104
|
+
[Python 3.10+](https://www.python.org/downloads/)
|
|
105
|
+
[PyTorch](https://pytorch.org/)
|
|
106
|
+
[Tests](https://github.com/criticaldata/ShortKit-ML/actions/workflows/tests.yml)
|
|
107
|
+
[Dataset on Hugging Face](https://huggingface.co/datasets/MITCriticalData/ShortKit-ML-data)
|
|
108
|
+
[Documentation](https://criticaldata.github.io/ShortKit-ML/)
|
|
109
|
+
|
|
110
|
+
## Table of Contents
|
|
111
|
+
|
|
112
|
+
- [Overview](#overview)
|
|
113
|
+
- [Installation](#installation)
|
|
114
|
+
- [Quick Start](#quick-start)
|
|
115
|
+
- [Detection Methods](#detection-methods)
|
|
116
|
+
- [Overall Assessment Conditions](#overall-assessment-conditions)
|
|
117
|
+
- [MCP Server](#mcp-server)
|
|
118
|
+
- [Paper Benchmarks](#paper-benchmark-datasets)
|
|
119
|
+
- [Reproducing Paper Results](#reproducing-paper-results)
|
|
120
|
+
- [GPU Support](#gpu-support)
|
|
121
|
+
- [Interactive Dashboard](#interactive-dashboard)
|
|
122
|
+
- [Testing](#testing)
|
|
123
|
+
- [Contributing](#contributing)
|
|
124
|
+
- [Citation](#citation)
|
|
125
|
+
|
|
126
|
+
## Overview
|
|
127
|
+
|
|
128
|
+
ShortKit-ML provides a comprehensive toolkit for detecting and mitigating shortcuts (unwanted biases) in embedding spaces:
|
|
129
|
+
|
|
130
|
+
- **20+ detection methods**: HBAC, Probe, Statistical, Geometric, Bias Direction PCA, Equalized Odds, Demographic Parity, Intersectional, GroupDRO, GCE, Causal Effect, SSA, SIS, CAV, VAE, Early-Epoch Clustering, and more
|
|
131
|
+
- **6 mitigation methods**: Shortcut Masking, Background Randomization, Adversarial Debiasing, Explanation Regularization, Last Layer Retraining, Contrastive Debiasing
|
|
132
|
+
- **5 pluggable risk conditions**: indicator_count, majority_vote, weighted_risk, multi_attribute, meta_classifier
|
|
133
|
+
|
|
134
|
+
**Key Features:**
|
|
135
|
+
- Unified `ShortcutDetector` API for all methods
|
|
136
|
+
- Interactive Gradio dashboard with real-time analysis
|
|
137
|
+
- PDF/HTML/Markdown reports with visualizations
|
|
138
|
+
- Embedding-only mode (no model access needed)
|
|
139
|
+
- Multi-attribute support: test race, gender, age simultaneously
|
|
140
|
+
- Model Comparison mode: compare multiple embedding models side-by-side
|
|
141
|
+
|
|
142
|
+
## Installation
|
|
143
|
+
|
|
144
|
+
```bash
|
|
145
|
+
pip install shortkit-ml
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
For all optional extras (dashboard, reporting, VAE, HuggingFace, etc.):
|
|
149
|
+
|
|
150
|
+
```bash
|
|
151
|
+
pip install "shortkit-ml[all]"
|
|
152
|
+
```
|
|
153
|
+
|
|
154
|
+
### Development Install (from source)
|
|
155
|
+
|
|
156
|
+
```bash
|
|
157
|
+
git clone https://github.com/criticaldata/ShortKit-ML.git
|
|
158
|
+
cd ShortKit-ML
|
|
159
|
+
pip install -e ".[all]"
|
|
160
|
+
```
|
|
161
|
+
|
|
162
|
+
Or with `uv`:
|
|
163
|
+
|
|
164
|
+
```bash
|
|
165
|
+
uv venv --python 3.10
|
|
166
|
+
source .venv/bin/activate # Windows: .venv\Scripts\activate
|
|
167
|
+
uv pip install -e ".[all]"
|
|
168
|
+
```
|
|
169
|
+
|
|
170
|
+
### Optional: PDF Export Dependencies
|
|
171
|
+
|
|
172
|
+
```bash
|
|
173
|
+
# macOS
|
|
174
|
+
brew install pango gdk-pixbuf libffi
|
|
175
|
+
# Ubuntu/Debian
|
|
176
|
+
sudo apt-get install libpango-1.0-0 libpangocairo-1.0-0 libgdk-pixbuf2.0-0
|
|
177
|
+
```
|
|
178
|
+
|
|
179
|
+
> HTML and Markdown reports work without these. PDF export is optional.
|
|
180
|
+
|
|
181
|
+
## Quick Start
|
|
182
|
+
|
|
183
|
+
```python
|
|
184
|
+
from shortcut_detect import ShortcutDetector
|
|
185
|
+
import numpy as np
|
|
186
|
+
|
|
187
|
+
embeddings = np.load("embeddings.npy") # (n_samples, embedding_dim)
|
|
188
|
+
labels = np.load("labels.npy") # (n_samples,)
|
|
189
|
+
|
|
190
|
+
detector = ShortcutDetector(methods=['hbac', 'probe', 'statistical', 'geometric', 'equalized_odds'])
|
|
191
|
+
detector.fit(embeddings, labels)
|
|
192
|
+
|
|
193
|
+
detector.generate_report("report.html", format="html")
|
|
194
|
+
print(detector.summary())
|
|
195
|
+
```
|
|
196
|
+
|
|
197
|
+
### Embedding-Only Mode
|
|
198
|
+
|
|
199
|
+
For closed-source models or systems that only expose embeddings:
|
|
200
|
+
|
|
201
|
+
```python
|
|
202
|
+
from shortcut_detect import ShortcutDetector, HuggingFaceEmbeddingSource
|
|
203
|
+
|
|
204
|
+
hf_source = HuggingFaceEmbeddingSource(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
|
205
|
+
detector = ShortcutDetector(methods=["probe", "statistical"])
|
|
206
|
+
detector.fit(embeddings=None, labels=labels, group_labels=groups,
|
|
207
|
+
             raw_inputs=texts, embedding_source=hf_source)
|
|
208
|
+
```
|
|
209
|
+
|
|
210
|
+
> See [Embedding-Only Guide](docs/methods/overview.md) for `CallableEmbeddingSource` and caching options.
|
|
211
|
+
|
|
212
|
+
## Detection Methods
|
|
213
|
+
|
|
214
|
+
| Method | Key | What It Detects | Reference |
|
|
215
|
+
|--------|-----|-----------------|-----------|
|
|
216
|
+
| **HBAC** | `hbac` | Clustering by protected attributes | - |
|
|
217
|
+
| **Probe** | `probe` | Group info recoverable from embeddings | - |
|
|
218
|
+
| **Statistical** | `statistical` | Dimensions with group differences | - |
|
|
219
|
+
| **Geometric** | `geometric` | Bias directions & prototype overlap | - |
|
|
220
|
+
| **Bias Direction PCA** | `bias_direction_pca` | Projection gap along bias direction | Bolukbasi 2016 |
|
|
221
|
+
| **Equalized Odds** | `equalized_odds` | TPR/FPR disparities | Hardt 2016 |
|
|
222
|
+
| **Demographic Parity** | `demographic_parity` | Prediction rate disparities | Feldman 2015 |
|
|
223
|
+
| **Early Epoch Clustering** | `early_epoch_clustering` | Shortcut reliance in early reps | Yang 2023 |
|
|
224
|
+
| **GCE** | `gce` | High-loss minority samples | - |
|
|
225
|
+
| **Frequency** | `frequency` | Signal in few dimensions | - |
|
|
226
|
+
| **GradCAM Mask Overlap** | `gradcam_mask_overlap` | Attention overlap with shortcut masks | - |
|
|
227
|
+
| **SpRAy** | `spray` | Spectral clustering of heatmaps | Lapuschkin 2019 |
|
|
228
|
+
| **CAV** | `cav` | Concept-level sensitivity | Kim 2018 |
|
|
229
|
+
| **Causal Effect** | `causal_effect` | Spurious attribute influence | - |
|
|
230
|
+
| **VAE** | `vae` | Latent disentanglement signatures | - |
|
|
231
|
+
| **SSA** | `ssa` | Semi-supervised spectral shift | [arXiv:2204.02070](https://arxiv.org/abs/2204.02070) |
|
|
232
|
+
| **Generative CVAE** | `generative_cvae` | Counterfactual embedding shifts | - |
|
|
233
|
+
| **GroupDRO** | `groupdro` | Worst-group performance gaps | Sagawa 2020 |
|
|
234
|
+
| **SIS** | `sis` | Sufficient input subsets (minimal dims for prediction) | Carter 2019 |
|
|
235
|
+
| **Intersectional** | `intersectional` | Intersectional fairness gaps (2+ attributes) | Buolamwini 2018 |
|
|
236
|
+
|
|
237
|
+
### Mitigation Methods
|
|
238
|
+
|
|
239
|
+
| Method | Class | Strategy | Reference |
|
|
240
|
+
|--------|-------|----------|-----------|
|
|
241
|
+
| **Shortcut Masking** | `ShortcutMasker` | Zero/randomize/inpaint shortcut regions | - |
|
|
242
|
+
| **Background Randomization** | `BackgroundRandomizer` | Swap foreground across backgrounds | - |
|
|
243
|
+
| **Adversarial Debiasing** | `AdversarialDebiasing` | Remove group information adversarially | Zhang 2018 |
|
|
244
|
+
| **Explanation Regularization** | `ExplanationRegularization` | Penalize attention on shortcuts (RRR) | Ross 2017 |
|
|
245
|
+
| **Last Layer Retraining** | `LastLayerRetraining` | Retrain final layer balanced (DFR) | Kirichenko 2023 |
|
|
246
|
+
| **Contrastive Debiasing** | `ContrastiveDebiasing` | Contrastive loss to align groups (CNC) | - |
|
|
247
|
+
|
|
248
|
+
> See [Detection Methods Overview](docs/methods/overview.md) for per-method usage, interpretation guides, and code examples.
|
|
249
|
+
|
|
250
|
+
## Overall Assessment Conditions
|
|
251
|
+
|
|
252
|
+
`ShortcutDetector` supports pluggable risk aggregation conditions that control how method-level results map to the final HIGH/MODERATE/LOW summary.
|
|
253
|
+
|
|
254
|
+
| Condition | Best For | Description |
|
|
255
|
+
|-----------|----------|-------------|
|
|
256
|
+
| `indicator_count` | General use (default) | Count of risk signals: 2+ = HIGH, 1 = MODERATE, 0 = LOW |
|
|
257
|
+
| `majority_vote` | Conservative screening | Consensus across methods |
|
|
258
|
+
| `weighted_risk` | Nuanced analysis | Evidence strength matters (probe accuracy, effect sizes, etc.) |
|
|
259
|
+
| `multi_attribute` | Multi-demographic | Escalates when multiple attributes flag risk |
|
|
260
|
+
| `meta_classifier` | Trained pipelines | Logistic regression meta-model on detector outputs (bundled model included) |
|
|
261
|
+
|
|
262
|
+
```python
|
|
263
|
+
detector = ShortcutDetector(
|
|
264
|
+
methods=["probe", "statistical"],
|
|
265
|
+
condition_name="weighted_risk",
|
|
266
|
+
condition_kwargs={"high_threshold": 0.6, "moderate_threshold": 0.3},
|
|
267
|
+
)
|
|
268
|
+
```
|
|
269
|
+
|
|
270
|
+
Custom conditions can be registered via `@register_condition("name")`. See [Conditions API](docs/api/shortcut-detector.md) for details.
|
|
271
|
+
|
|
272
|
+
## MCP Server
|
|
273
|
+
|
|
274
|
+
ShortKit-ML ships an [MCP](https://modelcontextprotocol.io/) server so AI assistants (Claude, Cursor, etc.) can call detection tools directly from chat — no Python script required.
|
|
275
|
+
|
|
276
|
+
### Install the MCP extra
|
|
277
|
+
|
|
278
|
+
```bash
|
|
279
|
+
pip install "shortkit-ml[mcp]"   # from a source checkout: pip install -e ".[mcp]"
|
|
280
|
+
```
|
|
281
|
+
|
|
282
|
+
### Start the server
|
|
283
|
+
|
|
284
|
+
```bash
|
|
285
|
+
# via entry point (after install)
|
|
286
|
+
shortkit-ml-mcp
|
|
287
|
+
|
|
288
|
+
# or directly
|
|
289
|
+
python -m shortcut_detect.mcp_server
|
|
290
|
+
```
|
|
291
|
+
|
|
292
|
+
### Available tools
|
|
293
|
+
|
|
294
|
+
| Tool | Description |
|
|
295
|
+
|------|-------------|
|
|
296
|
+
| `list_methods` | List all 19 detection methods with descriptions |
|
|
297
|
+
| `generate_synthetic_data` | Generate a synthetic shortcut dataset (linear / nonlinear / none) |
|
|
298
|
+
| `run_detector` | Run selected methods on embeddings — returns verdict, risk level, per-method breakdown |
|
|
299
|
+
| `get_summary` | Human-readable summary from a prior `run_detector` call |
|
|
300
|
+
| `get_method_detail` | Full raw result dict for a single method |
|
|
301
|
+
| `compare_methods` | Side-by-side comparison table + consensus vote across methods |
|
|
302
|
+
|
|
303
|
+
### Connect to Claude Desktop
|
|
304
|
+
|
|
305
|
+
Add the following to `~/Library/Application Support/Claude/claude_desktop_config.json` (macOS):
|
|
306
|
+
|
|
307
|
+
```json
|
|
308
|
+
{
|
|
309
|
+
"mcpServers": {
|
|
310
|
+
"shortkit-ml": {
|
|
311
|
+
"command": "python",
|
|
312
|
+
"args": ["-m", "shortcut_detect.mcp_server"],
|
|
313
|
+
"cwd": "/path/to/ShortKit-ML"
|
|
314
|
+
}
|
|
315
|
+
}
|
|
316
|
+
}
|
|
317
|
+
```
|
|
318
|
+
|
|
319
|
+
A ready-to-edit template is included at [`claude_desktop_config.json`](claude_desktop_config.json).
|
|
320
|
+
|
|
321
|
+
## Paper Benchmark Datasets
|
|
322
|
+
|
|
323
|
+
### Dataset 1 -- Synthetic Grid
|
|
324
|
+
|
|
325
|
+
Configure `examples/paper_benchmark_config.json` to control effect sizes, sample sizes, imbalance ratios, and embedding dimensionalities. A smoke profile (`examples/paper_benchmark_config_smoke.json`) is provided for quick sanity checks.
|
|
326
|
+
|
|
327
|
+
```bash
|
|
328
|
+
python -m shortcut_detect.benchmark.paper_run --config examples/paper_benchmark_config.json
|
|
329
|
+
```
|
|
330
|
+
|
|
331
|
+
The run writes CSVs, figures, and a Markdown summary to `output/paper_benchmark/`.
|
|
332
|
+
|
|
333
|
+
### Dataset 2 -- CheXpert Real Data
|
|
334
|
+
|
|
335
|
+
Requires a CheXpert manifest (`data/chexpert_manifest.csv`) plus model-specific embedding pickles. Supported models: `medclip`, `biomedclip`, `cxr-foundation`.
|
|
336
|
+
|
|
337
|
+
```bash
|
|
338
|
+
python3 scripts/run_dataset2_benchmark.py \
|
|
339
|
+
--manifest data/chexpert_manifest.csv \
|
|
340
|
+
--model medclip \
|
|
341
|
+
--root . \
|
|
342
|
+
--artifacts-dir output/paper_benchmark/chexpert_embeddings \
|
|
343
|
+
--config examples/paper_benchmark_config.json
|
|
344
|
+
```
|
|
345
|
+
|
|
346
|
+
See `scripts/reproduce_paper.sh` and the Dockerfile for full reproducibility.
|
|
347
|
+
|
|
348
|
+
## Reproducing Paper Results
|
|
349
|
+
|
|
350
|
+
All paper results are fully reproducible with fixed seeds (`seed=42`). Every table and figure in the paper can be regenerated from the scripts and data in this repository.
|
|
351
|
+
|
|
352
|
+
**13 benchmark methods** are evaluated across all datasets: `hbac`, `probe`, `statistical`, `geometric`, `frequency`, `bias_direction_pca`, `sis`, `demographic_parity`, `equalized_odds`, `intersectional`, `groupdro`, `gce`, `ssa`. These span 5 paradigms: embedding-level analysis, representation geometry, fairness evaluation, explainability, and training dynamics.
|
|
353
|
+
|
|
354
|
+
### Step-by-step Reproduction
|
|
355
|
+
|
|
356
|
+
| Step | Command | Output | Time |
|
|
357
|
+
|------|---------|--------|------|
|
|
358
|
+
| 1. Install | `pip install -e ".[all]"` | Package + deps | 2 min |
|
|
359
|
+
| 2. Synthetic benchmarks | `python scripts/generate_all_paper_tables.py` | `output/paper_tables/*.tex` | ~10 min |
|
|
360
|
+
| 3. Paper figures | `python scripts/generate_paper_figures.py` | `output/paper_figures/*.pdf` | ~2 min |
|
|
361
|
+
| 4. CheXpert benchmark | `python scripts/run_chexpert_benchmark.py` | `output/paper_benchmark/chexpert_results/` | ~1 min |
|
|
362
|
+
| 5. MIMIC-CXR setup | `python scripts/setup_mimic_cxr_data.py` | `data/mimic_cxr/*.npy` | ~1 min |
|
|
363
|
+
| 6. MIMIC-CXR benchmark | `python scripts/run_mimic_benchmark.py` | `output/paper_benchmark/mimic_cxr_results/` | ~2 min |
|
|
364
|
+
| 7. CelebA extraction | `python scripts/extract_celeba_embeddings.py` | `data/celeba/celeba_real_*.npy` | ~5 min (MPS) |
|
|
365
|
+
| 8. CelebA benchmark | `python scripts/run_celeba_real_benchmark.py` | `output/paper_benchmark/celeba_real_results/` | ~1 min |
|
|
366
|
+
| 9. Full pipeline (smoke) | `./scripts/reproduce_paper.sh smoke` | All synthetic outputs | ~5 min |
|
|
367
|
+
| 10. Full pipeline | `./scripts/reproduce_paper.sh full` | All synthetic outputs | ~2-4 hrs |
|
|
368
|
+
|
|
369
|
+
### Docker (fully self-contained)
|
|
370
|
+
```bash
|
|
371
|
+
docker build -t shortcut-detect .
|
|
372
|
+
docker run --rm -v "$(pwd)/output:/app/output" shortcut-detect full
|
|
373
|
+
```
|
|
374
|
+
|
|
375
|
+
### Data Sources
|
|
376
|
+
|
|
377
|
+
All embeddings are hosted on HuggingFace: **[MITCriticalData/ShortKit-ML-data](https://huggingface.co/datasets/MITCriticalData/ShortKit-ML-data)**
|
|
378
|
+
|
|
379
|
+
```bash
|
|
380
|
+
# Download all data into data/
|
|
381
|
+
huggingface-cli download MITCriticalData/ShortKit-ML-data --repo-type dataset --local-dir data/
|
|
382
|
+
```
|
|
383
|
+
|
|
384
|
+
| Dataset | Location | Embedding Models | Dim | Samples |
|
|
385
|
+
|---------|----------|-----------------|-----|---------|
|
|
386
|
+
| Synthetic | Generated at runtime | `SyntheticGenerator(seed=42)` | 128 | Configurable |
|
|
387
|
+
| CheXpert | `data/chexpert/` | MedCLIP, ResNet-50, DenseNet-121, ViT-B/16, ViT-B/32, DINOv2, RAD-DINO, MedSigLIP | 512-2048 | 2,000 each |
|
|
388
|
+
| MIMIC-CXR | `data/mimic_cxr/` | RAD-DINO, ViT-B/16, ViT-B/32, MedSigLIP | 768-1152 | ~1,500 each |
|
|
389
|
+
| CelebA | `data/celeba/` | ResNet-50 (ImageNet) | 2,048 | 10,000 |
|
|
390
|
+
|
|
391
|
+
### Paper Tables → Scripts Mapping
|
|
392
|
+
|
|
393
|
+
| Paper Table | Script | Data | Seed |
|
|
394
|
+
|-------------|--------|------|------|
|
|
395
|
+
| Tab 3: Synthetic P/R/F1 | `generate_all_paper_tables.py` | `SyntheticGenerator` | 42 |
|
|
396
|
+
| Tab 4: False positive rates | `generate_all_paper_tables.py` | `SyntheticGenerator` (null) | 42 |
|
|
397
|
+
| Tab 5: Sensitivity analysis | `generate_all_paper_tables.py` | `SensitivitySweep` | 42 |
|
|
398
|
+
| Tab 6: CheXpert results | `run_chexpert_benchmark.py` | `data/chest_embeddings.npy` | 42 |
|
|
399
|
+
| Tab 7: MIMIC-CXR cross-val | `run_mimic_benchmark.py` | `data/mimic_cxr/*.npy` | 42 |
|
|
400
|
+
| Tab 8: CelebA validation | `run_celeba_real_benchmark.py` | `data/celeba/celeba_real_embeddings.npy` | 42 |
|
|
401
|
+
| Tab 9: Risk conditions | `generate_all_paper_tables.py` | `SyntheticGenerator` | 42 |
|
|
402
|
+
| Fig 2: Convergence matrix | `generate_paper_figures.py` | Synthetic + CheXpert | 42 |
|
|
403
|
+
|
|
404
|
+
See `docs/reproducibility.md` for full details.
|
|
405
|
+
|
|
406
|
+
## GPU Support
|
|
407
|
+
|
|
408
|
+
The library auto-selects the best available device. PyTorch components (probes, VAE, GroupDRO, adversarial debiasing, etc.) use the standard `torch.device` fallback:
|
|
409
|
+
|
|
410
|
+
| Platform | Backend | Auto-detected |
|
|
411
|
+
|----------|---------|---------------|
|
|
412
|
+
| Linux/Windows with NVIDIA GPU | CUDA | Yes (`torch.cuda.is_available()`) |
|
|
413
|
+
| macOS Apple Silicon | MPS | Partial -- pass `device="mps"` explicitly |
|
|
414
|
+
| CPU (any platform) | CPU | Yes (default fallback) |
|
|
415
|
+
|
|
416
|
+
> **Note:** Most detection methods (HBAC, statistical, geometric, etc.) run on CPU via NumPy/scikit-learn and do not require GPU. GPU acceleration benefits the torch-based probe, VAE, GroupDRO, and mitigation methods. MPS support depends on PyTorch operator coverage; if you encounter errors on Apple Silicon, fall back to `device="cpu"`.
|
|
417
|
+
|
|
418
|
+
## Interactive Dashboard
|
|
419
|
+
|
|
420
|
+
```bash
|
|
421
|
+
python app.py
|
|
422
|
+
# Opens at http://127.0.0.1:7860
|
|
423
|
+
```
|
|
424
|
+
|
|
425
|
+
Features: sample CheXpert data, custom CSV upload, PDF/HTML reports, model comparison tab, multi-attribute analysis.
|
|
426
|
+
|
|
427
|
+
**CSV Format:**
|
|
428
|
+
```csv
|
|
429
|
+
embedding_0,embedding_1,...,task_label,group_label,attr_race,attr_gender
|
|
430
|
+
0.123,0.456,...,1,group_a,Black,Male
|
|
431
|
+
```
|
|
432
|
+
|
|
433
|
+
> See [Dashboard Guide](docs/getting-started/dashboard.md) for detailed usage.
|
|
434
|
+
|
|
435
|
+
## Testing
|
|
436
|
+
|
|
437
|
+
```bash
|
|
438
|
+
pytest tests/ -v
|
|
439
|
+
pytest --cov=shortcut_detect --cov-report=html
|
|
440
|
+
```
|
|
441
|
+
|
|
442
|
+
**638 tests passing** across all detection and mitigation methods.
|
|
443
|
+
|
|
444
|
+
## Contributing
|
|
445
|
+
|
|
446
|
+
```bash
|
|
447
|
+
pip install -e ".[dev]"
|
|
448
|
+
pre-commit install
|
|
449
|
+
```
|
|
450
|
+
|
|
451
|
+
- **Black** for formatting (line length: 100), **Ruff** for linting, **MyPy** for types
|
|
452
|
+
- Pre-commit hooks run automatically; CI tests on Python 3.10, 3.11, 3.12
|
|
453
|
+
- New detectors must implement `DetectorBase`. See `docs/contributing.md` and `shortcut_detect/detector_template.py`
|
|
454
|
+
|
|
455
|
+
## Project Structure
|
|
456
|
+
|
|
457
|
+
```
|
|
458
|
+
shortcut_detect/
|
|
459
|
+
├── probes/ # Probe-based detection (sklearn + torch)
|
|
460
|
+
├── clustering/ # HBAC detector
|
|
461
|
+
├── statistical/ # Statistical testing
|
|
462
|
+
├── geometric/ # Geometric & bias direction analysis
|
|
463
|
+
├── fairness/ # Equalized Odds, Demographic Parity, Intersectional
|
|
464
|
+
├── frequency/ # Frequency shortcut detector
|
|
465
|
+
├── causal/ # Causal effect detector
|
|
466
|
+
├── gce/ # Generalized cross-entropy detector
|
|
467
|
+
├── training/ # Early epoch clustering (SPARE)
|
|
468
|
+
├── vae/ # VAE latent disentanglement
|
|
469
|
+
├── xai/ # CAV, SpRAy, GradCAM mask overlap, SIS
|
|
470
|
+
├── ssa/ # Semi-supervised spectral analysis
|
|
471
|
+
├── groupdro/ # GroupDRO worst-group robustness
|
|
472
|
+
├── conditions/ # Pluggable risk aggregation conditions
|
|
473
|
+
│ ├── base.py, registry.py, indicator_count.py, majority_vote.py
|
|
474
|
+
│ ├── weighted_risk.py, multi_attribute.py, meta_classifier.py
|
|
475
|
+
│ └── meta_model.joblib # Trained meta-classifier (bundled)
|
|
476
|
+
├── benchmark/ # Paper benchmark infrastructure
|
|
477
|
+
│ ├── runner.py, paper_runner.py, synthetic_generator.py
|
|
478
|
+
│ ├── measurement.py, fp_analysis.py, sensitivity.py
|
|
479
|
+
│ ├── convergence_viz.py, baseline_comparison.py, figures.py
|
|
480
|
+
├── comparison/ # Model comparison runner
|
|
481
|
+
├── mitigation/ # Debiasing & masking methods (M01-M07)
|
|
482
|
+
├── reporting/ # HTML/PDF/CSV reports & visualizations
|
|
483
|
+
├── unified.py # ShortcutDetector unified API
|
|
484
|
+
└── detector_base.py # DetectorBase ABC with results_ schema
|
|
485
|
+
|
|
486
|
+
docs/ # MkDocs documentation site
|
|
487
|
+
examples/ # Notebooks and benchmark configs
|
|
488
|
+
app.py # Gradio dashboard
|
|
489
|
+
Dockerfile # Reproducible environment
|
|
490
|
+
scripts/ # Paper reproduction scripts
|
|
491
|
+
tests/ # Test suite (638 tests)
|
|
492
|
+
```
|
|
493
|
+
|
|
494
|
+
## Documentation
|
|
495
|
+
|
|
496
|
+
```bash
|
|
497
|
+
pip install mkdocs mkdocs-material "mkdocstrings[python]" pymdown-extensions
|
|
498
|
+
mkdocs serve # http://127.0.0.1:8000
|
|
499
|
+
```
|
|
500
|
+
|
|
501
|
+
- [Getting Started](docs/getting-started/installation.md)
|
|
502
|
+
- [Detection Methods](docs/methods/overview.md) -- all 20+ methods with guides
|
|
503
|
+
- [API Reference](docs/api/shortcut-detector.md)
|
|
504
|
+
- [Contributing](docs/contributing.md)
|
|
505
|
+
|
|
506
|
+
## Citation
|
|
507
|
+
|
|
508
|
+
```bibtex
|
|
509
|
+
@software{shortkit_ml2025,
|
|
510
|
+
title={ShortKit-ML: Tools for Identifying Biases in Embedding Spaces},
|
|
511
|
+
author={Sebastian Cajas and Aldo Marzullo and Sahil Kapadia and Qingpeng Kong and Filipe Santos and Alessandro Quarta and Leo Celi},
|
|
512
|
+
year={2025},
|
|
513
|
+
url={https://github.com/criticaldata/ShortKit-ML}
|
|
514
|
+
}
|
|
515
|
+
```
|
|
516
|
+
|
|
517
|
+
## License
|
|
518
|
+
|
|
519
|
+
Released under the MIT License. See the [LICENSE](LICENSE) file for details.
|
|
520
|
+
|
|
521
|
+
## Contact
|
|
522
|
+
|
|
523
|
+
- **GitHub**: [criticaldata/ShortKit-ML](https://github.com/criticaldata/ShortKit-ML)
|
|
524
|
+
- **Issues**: [GitHub Issues](https://github.com/criticaldata/ShortKit-ML/issues)
|