interpreto 0.5.0.dev0__tar.gz → 0.5.0.dev1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/PKG-INFO +23 -13
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/README.md +21 -11
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/__init__.py +2 -1
- interpreto-0.5.0.dev1/interpreto/_vendor/overcomplete/base.pyi +16 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/base.py +101 -5
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/kernel_shap.py +3 -2
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/lime.py +3 -2
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/occlusion.py +4 -4
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/sobol_attribution.py +3 -2
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/__init__.py +22 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/base.py +116 -43
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/base.py +11 -5
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/llm_labels.py +1 -1
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/topk_inputs.py +1 -1
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/sklearn_wrappers.py +1 -1
- interpreto-0.5.0.dev1/interpreto/concepts/probes/__init__.py +81 -0
- interpreto-0.5.0.dev1/interpreto/concepts/probes/base.py +279 -0
- interpreto-0.5.0.dev1/interpreto/concepts/probes/bias_calibrators.py +223 -0
- interpreto-0.5.0.dev1/interpreto/concepts/probes/centroid.py +432 -0
- interpreto-0.5.0.dev1/interpreto/concepts/probes/linear.py +418 -0
- interpreto-0.5.0.dev1/interpreto/concepts/probes/normalizations.py +201 -0
- interpreto-0.5.0.dev1/interpreto/concepts/probes/sklearn.py +140 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/__init__.py +2 -1
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/inference_wrapper.py +2 -0
- interpreto-0.5.0.dev1/interpreto/model_wrapping/inputs_to_concepts_inference_wrapper.py +91 -0
- interpreto-0.5.0.dev1/interpreto/model_wrapping/split_sequence_classification.py +427 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/typing.py +1 -8
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto.egg-info/PKG-INFO +23 -13
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto.egg-info/SOURCES.txt +10 -1
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto.egg-info/requires.txt +1 -1
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/mkdocs.yml +17 -9
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/pyproject.toml +2 -2
- interpreto-0.5.0.dev0/interpreto/concepts/plots/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/.editorconfig +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/.gitignore +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/.pre-commit-config.yaml +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/AGENTS.md +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/LICENSE +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/MANIFEST.in +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/VENDORED_FROM +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/_sync_overcomplete.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/base.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/data.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/metrics.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/archetypal_analysis.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/base.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/convex_nmf.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/nmf.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/semi_nmf.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/sklearn_wrappers.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/utils.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/overcomplete.patch +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/archetypal_dictionary.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/base.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/batchtopk_sae.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/dictionary.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/factory.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/jump_sae.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/kernels.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/losses.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/modules.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/mp_sae.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/omp_sae.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/optimizer.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/qsae.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/rasae.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/topk_sae.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/trackers.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/train.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/base.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/linear_regression_aggregation.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/sobol_aggregation.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/gradient_shap.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/integrated_gradients.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/saliency.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/smooth_grad.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/square_grad.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/var_grad.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/metrics/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/metrics/insertion_deletion.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/base.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/gaussian_noise_perturbation.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/gradient_shap_perturbation.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/insertion_deletion_perturbation.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/linear_interpolation_perturbation.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/occlusion_perturbation.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/random_perturbation.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/shap_perturbation.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/sobol_perturbation.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/plots/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/commons/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/commons/distances.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/commons/generator_tools.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/commons/granularity.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/cockatiel.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/neurons_as_concepts.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/overcomplete.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/consim.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/dictionary_metrics.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/reconstruction_metrics.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/sparsity_metrics.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/classification_inference_wrapper.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/generation_inference_wrapper.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/llm_interface.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/model_with_split_points.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/splitting_utils.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/transformers_classes.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/README.md +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/__init__.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/attributions.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/commons.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/concepts.py +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/css/visualization.css +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/dom_renderer.js +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/state_manager.js +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/style_computer.js +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/view_updater.js +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/attribution_classification.js +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/attribution_generation.js +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/concepts_classification_global.js +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/concepts_classification_local.js +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/concepts_generation_local.js +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto.egg-info/dependency_links.txt +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto.egg-info/top_level.txt +0 -0
- {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: interpreto
|
|
3
|
-
Version: 0.5.0.dev0
|
|
3
|
+
Version: 0.5.0.dev1
|
|
4
4
|
Summary: Interpretability toolbox for LLMs
|
|
5
5
|
Author: FOR Team
|
|
6
6
|
Author-email: fanny.jourdan@irt-saintexupery.com
|
|
@@ -57,7 +57,7 @@ License-File: LICENSE
|
|
|
57
57
|
Requires-Dist: transformers>=4.22.0
|
|
58
58
|
Requires-Dist: nltk
|
|
59
59
|
Requires-Dist: torch>=2.0
|
|
60
|
-
Requires-Dist: nnsight<0.
|
|
60
|
+
Requires-Dist: nnsight<0.8.0,>=0.7.0
|
|
61
61
|
Requires-Dist: jaxtyping<=0.2.36
|
|
62
62
|
Requires-Dist: beartype
|
|
63
63
|
Requires-Dist: mknotebooks
|
|
@@ -120,6 +120,7 @@ Dynamic: license-file
|
|
|
120
120
|
<p align="center">
|
|
121
121
|
<a href="https://for-sight-ai.github.io/interpreto/"><strong>📚 Explore Interpreto docs >></strong></a><br />
|
|
122
122
|
<a href="https://for-sight-ai.github.io/interpreto-demo/"><strong>🖼️ Checkout our explanation gallery >></strong></a>
|
|
123
|
+
<a href="https://arxiv.org/abs/2512.09730"><strong>📜 Read our paper >></strong></a>
|
|
123
124
|
</p>
|
|
124
125
|
|
|
125
126
|
## 🚀 Quick Start
|
|
@@ -163,41 +164,51 @@ They all work seamlessly for both classification (`...ForSequenceClassification`
|
|
|
163
164
|
|
|
164
165
|
Concept-based explanations aim to provide high-level interpretations of latent model representations.
|
|
165
166
|
|
|
166
|
-
|
|
167
|
+
We propose both supervised (probes and CAVs) and unsupervised (dictionary learning) approaches.
|
|
168
|
+
|
|
169
|
+
Interpreto generalizes these methods through four core steps, the two first are common between both approaches:
|
|
167
170
|
|
|
168
171
|
1. Split a model in two and obtain a dataset of activations
|
|
169
|
-
2.
|
|
170
|
-
3.
|
|
171
|
-
4.
|
|
172
|
+
2. Learn concepts (e.g., from latent embeddings)
|
|
173
|
+
3. Interpret concepts (mapping discovered concepts to human-understandable elements)
|
|
174
|
+
4. Estimate concepts importance (assessing concept relevance to model outputs)
|
|
172
175
|
|
|
173
176
|
**1. Split a model in two and obtain a dataset of activations:** (mainly via [`nnsight`](https://github.com/ndif-team/nnsight)):
|
|
174
177
|
|
|
175
178
|
Choose any layer in any HuggingFace language model with our `ModelWithSplitPoints` based on `nnsight`. Then pass a dataset through it to obtain a dataset of activations.
|
|
176
179
|
|
|
177
|
-
**2.
|
|
180
|
+
**2. (supervised) Train probe** with the [`ProbeExplainer`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/)
|
|
181
|
+
|
|
182
|
+
We differentiate two families of probes:
|
|
183
|
+
|
|
184
|
+
- Linear probes: [`LinearRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearRegressionProbe), [`LogisticRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LogisticRegressionProbe), [`LinearSVMProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearSVMProbe), [`MeansDiffProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.MeansDiffProbe)
|
|
185
|
+
- Centroid-based probes: [`CosineCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.CosineCentroidProbe), [`DotProductCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DotProductCentroidProbe), [`SqL2CentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SqL2CentroidProbe), [`SVDDCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SVDDCentroidProbe), [`DiagonalMahalanobisCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DiagonalMahalanobisCentroidProbe)
|
|
186
|
+
|
|
187
|
+
Both can be tuned with `bias_calibrator` and `normalization` parameters.
|
|
188
|
+
|
|
189
|
+
**2. (unsupervised) Dictionary Learning for Concept Discovery** (mainly via [`overcomplete`](https://github.com/KempnerInstitute/overcomplete)):
|
|
178
190
|
|
|
179
191
|
- Interpret neurons directly via [`NeuronsAsConcepts`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/neurons_as_concepts/)
|
|
180
192
|
- [`NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.NMFConcepts), [`Semi-NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SemiNMFConcepts), [`ConvexNMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ConvexNMFConcepts)
|
|
181
193
|
- [`ICA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ICAConcepts), [`SVD`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SVDConcepts), [`PCA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.PCAConcepts), [`KMeans`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.KMeansConcepts)
|
|
182
194
|
- SAE variants: [`Vanilla SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.VanillaSAEConcepts), [`TopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.TopKSAEConcepts), [`JumpReLU SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.JumpReLUSAEConcepts), [`BatchTopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.BatchTopKSAEConcepts)
|
|
183
195
|
|
|
184
|
-
**3. Available Concept Interpretation Techniques:**
|
|
196
|
+
**3. (unsupervised) Available Concept Interpretation Techniques:**
|
|
185
197
|
|
|
186
198
|
- Top-k tokens from tokenizer vocabulary via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs) and `use_vocab=True`
|
|
187
199
|
- Top-k tokens/words/sentences/samples from specific datasets via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs)
|
|
188
200
|
- Label concepts via LLMs with [`LLMLabels`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.LLMLabels) ([Bills et al. 2023](https://openai.com/index/language-models-can-explain-neurons-in-language-models/))
|
|
201
|
+
- Input-to-concept attribution from dataset examples ([Concept Attributions](https://for-sight-ai.github.io/interpreto/api/concepts/interpretations/concept_attributions/)) ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
|
|
189
202
|
|
|
190
203
|
<details><summary>Concept Interpretation Techniques Added in the future:</summary>
|
|
191
204
|
|
|
192
|
-
- Input-to-concept attribution from dataset examples ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
|
|
193
|
-
- Theme prediction via LLMs from top-k tokens/sentences
|
|
194
205
|
- Aligning concepts with human labels ([Sajjad et al. 2022](https://aclanthology.org/2022.naacl-main.225/))
|
|
195
206
|
- Word cloud visualizations of concepts ([Dalvi et al. 2022](https://arxiv.org/abs/2205.07237))
|
|
196
207
|
- VocabProj & TokenChange ([Gur-Arieh et al. 2025](https://arxiv.org/abs/2501.08319))
|
|
197
208
|
|
|
198
209
|
</details>
|
|
199
210
|
|
|
200
|
-
**4. Concept-to-Output Attribution:**
|
|
211
|
+
**4. (unsupervised) Concept-to-Output Attribution:**
|
|
201
212
|
|
|
202
213
|
Estimate the contribution of each concept to the model output.
|
|
203
214
|
|
|
@@ -207,7 +218,6 @@ Can be obtained with any concept-based explainer via [`MethodConcepts.concept_ou
|
|
|
207
218
|
|
|
208
219
|
Thanks to this generalization encompassing all concept-based methods and our highly flexible architecture, we can easily obtain a large number of concept-based methods:
|
|
209
220
|
|
|
210
|
-
- CAV and TCAV: [Kim et al. 2018, Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV)](http://proceedings.mlr.press/v80/kim18d.html)
|
|
211
221
|
- ConceptSHAP: [Yeh et al. 2020, On Completeness-aware Concept-Based Explanations in Deep Neural Networks](https://proceedings.neurips.cc/paper/2020/hash/ecb287ff763c169694f682af52c1f309-Abstract.html)
|
|
212
222
|
- COCKATIEL: [Jourdan et al. 2023, COCKATIEL: COntinuous Concept ranKed ATtribution with Interpretable ELements for explaining neural net classifiers on NLP](https://aclanthology.org/2023.findings-acl.317/)
|
|
213
223
|
- Yun et al. 2021, [Transformer visualization via dictionary learning: contextualized embedding as a linear superposition of transformer factors](https://arxiv.org/abs/2103.15949)
|
|
@@ -264,7 +274,7 @@ Interpreto 🪄 is a project of the [FOR](https://www.irt-saintexupery.com/fr/fo
|
|
|
264
274
|
|
|
265
275
|
## 🗞️ Citation
|
|
266
276
|
|
|
267
|
-
If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ our paper:
|
|
277
|
+
If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ [our paper](https://arxiv.org/abs/2512.09730):
|
|
268
278
|
|
|
269
279
|
```bibtex
|
|
270
280
|
@article{poche2025interpreto,
|
|
@@ -15,6 +15,7 @@
|
|
|
15
15
|
<p align="center">
|
|
16
16
|
<a href="https://for-sight-ai.github.io/interpreto/"><strong>📚 Explore Interpreto docs >></strong></a><br />
|
|
17
17
|
<a href="https://for-sight-ai.github.io/interpreto-demo/"><strong>🖼️ Checkout our explanation gallery >></strong></a>
|
|
18
|
+
<a href="https://arxiv.org/abs/2512.09730"><strong>📜 Read our paper >></strong></a>
|
|
18
19
|
</p>
|
|
19
20
|
|
|
20
21
|
## 🚀 Quick Start
|
|
@@ -58,41 +59,51 @@ They all work seamlessly for both classification (`...ForSequenceClassification`
|
|
|
58
59
|
|
|
59
60
|
Concept-based explanations aim to provide high-level interpretations of latent model representations.
|
|
60
61
|
|
|
61
|
-
|
|
62
|
+
We propose both supervised (probes and CAVs) and unsupervised (dictionary learning) approaches.
|
|
63
|
+
|
|
64
|
+
Interpreto generalizes these methods through four core steps, the two first are common between both approaches:
|
|
62
65
|
|
|
63
66
|
1. Split a model in two and obtain a dataset of activations
|
|
64
|
-
2.
|
|
65
|
-
3.
|
|
66
|
-
4.
|
|
67
|
+
2. Learn concepts (e.g., from latent embeddings)
|
|
68
|
+
3. Interpret concepts (mapping discovered concepts to human-understandable elements)
|
|
69
|
+
4. Estimate concepts importance (assessing concept relevance to model outputs)
|
|
67
70
|
|
|
68
71
|
**1. Split a model in two and obtain a dataset of activations:** (mainly via [`nnsight`](https://github.com/ndif-team/nnsight)):
|
|
69
72
|
|
|
70
73
|
Choose any layer in any HuggingFace language model with our `ModelWithSplitPoints` based on `nnsight`. Then pass a dataset through it to obtain a dataset of activations.
|
|
71
74
|
|
|
72
|
-
**2.
|
|
75
|
+
**2. (supervised) Train probe** with the [`ProbeExplainer`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/)
|
|
76
|
+
|
|
77
|
+
We differentiate two families of probes:
|
|
78
|
+
|
|
79
|
+
- Linear probes: [`LinearRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearRegressionProbe), [`LogisticRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LogisticRegressionProbe), [`LinearSVMProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearSVMProbe), [`MeansDiffProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.MeansDiffProbe)
|
|
80
|
+
- Centroid-based probes: [`CosineCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.CosineCentroidProbe), [`DotProductCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DotProductCentroidProbe), [`SqL2CentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SqL2CentroidProbe), [`SVDDCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SVDDCentroidProbe), [`DiagonalMahalanobisCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DiagonalMahalanobisCentroidProbe)
|
|
81
|
+
|
|
82
|
+
Both can be tuned with `bias_calibrator` and `normalization` parameters.
|
|
83
|
+
|
|
84
|
+
**2. (unsupervised) Dictionary Learning for Concept Discovery** (mainly via [`overcomplete`](https://github.com/KempnerInstitute/overcomplete)):
|
|
73
85
|
|
|
74
86
|
- Interpret neurons directly via [`NeuronsAsConcepts`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/neurons_as_concepts/)
|
|
75
87
|
- [`NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.NMFConcepts), [`Semi-NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SemiNMFConcepts), [`ConvexNMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ConvexNMFConcepts)
|
|
76
88
|
- [`ICA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ICAConcepts), [`SVD`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SVDConcepts), [`PCA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.PCAConcepts), [`KMeans`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.KMeansConcepts)
|
|
77
89
|
- SAE variants: [`Vanilla SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.VanillaSAEConcepts), [`TopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.TopKSAEConcepts), [`JumpReLU SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.JumpReLUSAEConcepts), [`BatchTopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.BatchTopKSAEConcepts)
|
|
78
90
|
|
|
79
|
-
**3. Available Concept Interpretation Techniques:**
|
|
91
|
+
**3. (unsupervised) Available Concept Interpretation Techniques:**
|
|
80
92
|
|
|
81
93
|
- Top-k tokens from tokenizer vocabulary via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs) and `use_vocab=True`
|
|
82
94
|
- Top-k tokens/words/sentences/samples from specific datasets via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs)
|
|
83
95
|
- Label concepts via LLMs with [`LLMLabels`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.LLMLabels) ([Bills et al. 2023](https://openai.com/index/language-models-can-explain-neurons-in-language-models/))
|
|
96
|
+
- Input-to-concept attribution from dataset examples ([Concept Attributions](https://for-sight-ai.github.io/interpreto/api/concepts/interpretations/concept_attributions/)) ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
|
|
84
97
|
|
|
85
98
|
<details><summary>Concept Interpretation Techniques Added in the future:</summary>
|
|
86
99
|
|
|
87
|
-
- Input-to-concept attribution from dataset examples ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
|
|
88
|
-
- Theme prediction via LLMs from top-k tokens/sentences
|
|
89
100
|
- Aligning concepts with human labels ([Sajjad et al. 2022](https://aclanthology.org/2022.naacl-main.225/))
|
|
90
101
|
- Word cloud visualizations of concepts ([Dalvi et al. 2022](https://arxiv.org/abs/2205.07237))
|
|
91
102
|
- VocabProj & TokenChange ([Gur-Arieh et al. 2025](https://arxiv.org/abs/2501.08319))
|
|
92
103
|
|
|
93
104
|
</details>
|
|
94
105
|
|
|
95
|
-
**4. Concept-to-Output Attribution:**
|
|
106
|
+
**4. (unsupervised) Concept-to-Output Attribution:**
|
|
96
107
|
|
|
97
108
|
Estimate the contribution of each concept to the model output.
|
|
98
109
|
|
|
@@ -102,7 +113,6 @@ Can be obtained with any concept-based explainer via [`MethodConcepts.concept_ou
|
|
|
102
113
|
|
|
103
114
|
Thanks to this generalization encompassing all concept-based methods and our highly flexible architecture, we can easily obtain a large number of concept-based methods:
|
|
104
115
|
|
|
105
|
-
- CAV and TCAV: [Kim et al. 2018, Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV)](http://proceedings.mlr.press/v80/kim18d.html)
|
|
106
116
|
- ConceptSHAP: [Yeh et al. 2020, On Completeness-aware Concept-Based Explanations in Deep Neural Networks](https://proceedings.neurips.cc/paper/2020/hash/ecb287ff763c169694f682af52c1f309-Abstract.html)
|
|
107
117
|
- COCKATIEL: [Jourdan et al. 2023, COCKATIEL: COntinuous Concept ranKed ATtribution with Interpretable ELements for explaining neural net classifiers on NLP](https://aclanthology.org/2023.findings-acl.317/)
|
|
108
118
|
- Yun et al. 2021, [Transformer visualization via dictionary learning: contextualized embedding as a linear superposition of transformer factors](https://arxiv.org/abs/2103.15949)
|
|
@@ -159,7 +169,7 @@ Interpreto 🪄 is a project of the [FOR](https://www.irt-saintexupery.com/fr/fo
|
|
|
159
169
|
|
|
160
170
|
## 🗞️ Citation
|
|
161
171
|
|
|
162
|
-
If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ our paper:
|
|
172
|
+
If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ [our paper](https://arxiv.org/abs/2512.09730):
|
|
163
173
|
|
|
164
174
|
```bibtex
|
|
165
175
|
@article{poche2025interpreto,
|
|
@@ -42,7 +42,7 @@ from .attributions import (
|
|
|
42
42
|
from .commons import (
|
|
43
43
|
Granularity,
|
|
44
44
|
)
|
|
45
|
-
from .model_wrapping import ModelWithSplitPoints
|
|
45
|
+
from .model_wrapping import ModelWithSplitPoints, SplitSequenceClassification
|
|
46
46
|
from .visualizations import (
|
|
47
47
|
AttributionVisualization,
|
|
48
48
|
plot_attributions,
|
|
@@ -71,6 +71,7 @@ __all__ = [
|
|
|
71
71
|
"SquareGrad",
|
|
72
72
|
"Saliency",
|
|
73
73
|
"SmoothGrad",
|
|
74
|
+
"SplitSequenceClassification",
|
|
74
75
|
"Sobol",
|
|
75
76
|
"VarGrad",
|
|
76
77
|
"get_version",
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
class BaseDictionaryLearning(ABC, nn.Module):
|
|
7
|
+
nb_concepts: int
|
|
8
|
+
device: str | torch.device
|
|
9
|
+
fitted: bool
|
|
10
|
+
|
|
11
|
+
def __init__(self, nb_concepts: int, device: str | torch.device = "cpu") -> None: ...
|
|
12
|
+
def encode(self, x: torch.Tensor) -> torch.Tensor: ...
|
|
13
|
+
def decode(self, z: torch.Tensor) -> torch.Tensor: ...
|
|
14
|
+
def fit(self, x: torch.Tensor) -> None: ...
|
|
15
|
+
def get_dictionary(self) -> torch.Tensor: ...
|
|
16
|
+
def to(self, device: str | torch.device) -> "BaseDictionaryLearning": ... # type: ignore[override]
|
|
@@ -45,13 +45,17 @@ from interpreto.attributions.perturbations.base import Perturbator
|
|
|
45
45
|
from interpreto.commons import Granularity
|
|
46
46
|
from interpreto.commons.generator_tools import split_iterator
|
|
47
47
|
from interpreto.commons.granularity import GranularityAggregationStrategy
|
|
48
|
+
from interpreto.concepts.base import ModelForInputsToConcepts
|
|
48
49
|
from interpreto.model_wrapping.classification_inference_wrapper import ClassificationInferenceWrapper
|
|
49
50
|
from interpreto.model_wrapping.generation_inference_wrapper import GenerationInferenceWrapper
|
|
50
51
|
from interpreto.model_wrapping.inference_wrapper import InferenceModes, InferenceWrapper
|
|
52
|
+
from interpreto.model_wrapping.inputs_to_concepts_inference_wrapper import InputsToConceptsInferenceWrapper
|
|
51
53
|
from interpreto.typing import ClassificationTarget, GeneratedTarget, ModelInputs, SingleAttribution, TensorMapping
|
|
52
54
|
|
|
53
55
|
|
|
54
|
-
def setup_token_ids(
|
|
56
|
+
def setup_token_ids(
|
|
57
|
+
model: PreTrainedModel | ModelForInputsToConcepts, tokenizer: PreTrainedTokenizer, require_mask_token: bool = True
|
|
58
|
+
) -> int:
|
|
55
59
|
"""
|
|
56
60
|
Setup the tokenizer and the model with the appropriate token IDs, for padding and masking.
|
|
57
61
|
|
|
@@ -139,6 +143,7 @@ class ModelTask(Enum):
|
|
|
139
143
|
|
|
140
144
|
CLASSIFICATION = "classification"
|
|
141
145
|
GENERATION = "generation"
|
|
146
|
+
CONCEPTS = "concepts"
|
|
142
147
|
|
|
143
148
|
|
|
144
149
|
def clone_tensor_mapping(tm: TensorMapping, detach: bool = False) -> TensorMapping:
|
|
@@ -232,7 +237,7 @@ class AttributionExplainer:
|
|
|
232
237
|
|
|
233
238
|
def __init__(
|
|
234
239
|
self,
|
|
235
|
-
model: PreTrainedModel,
|
|
240
|
+
model: PreTrainedModel | ModelForInputsToConcepts,
|
|
236
241
|
tokenizer: PreTrainedTokenizer,
|
|
237
242
|
batch_size: int = 4,
|
|
238
243
|
perturbator: Perturbator | None = None,
|
|
@@ -248,7 +253,7 @@ class AttributionExplainer:
|
|
|
248
253
|
Initializes the AttributionExplainer.
|
|
249
254
|
|
|
250
255
|
Args:
|
|
251
|
-
model (PreTrainedModel): The model to be explained.
|
|
256
|
+
model (PreTrainedModel | ModelForInputsToConcepts): The model to be explained.
|
|
252
257
|
tokenizer (PreTrainedTokenizer): The tokenizer associated with the model.
|
|
253
258
|
batch_size (int): The batch size used for model inference.
|
|
254
259
|
perturbator (Perturbator, optional): Instance used to generate input perturbations.
|
|
@@ -276,7 +281,7 @@ class AttributionExplainer:
|
|
|
276
281
|
self.tokenizer = tokenizer
|
|
277
282
|
|
|
278
283
|
self.inference_wrapper = self._associated_inference_wrapper(
|
|
279
|
-
model,
|
|
284
|
+
model, # type: ignore
|
|
280
285
|
gradients=use_gradient,
|
|
281
286
|
input_x_gradient=input_x_gradient,
|
|
282
287
|
batch_size=batch_size,
|
|
@@ -846,13 +851,99 @@ class GenerationAttributionExplainer(AttributionExplainer):
|
|
|
846
851
|
return ModelTask.GENERATION, contribution
|
|
847
852
|
|
|
848
853
|
|
|
854
|
+
class InputsToConceptsAttributionsExplainer(AttributionExplainer):
|
|
855
|
+
"""Attribution explainer for input-to-concept models.
|
|
856
|
+
|
|
857
|
+
This explainer computes how much each input token contributes to each concept
|
|
858
|
+
activation. It bridges the attribution framework with the concept framework:
|
|
859
|
+
once a concept explainer is fitted, its ``inputs_to_concepts`` property returns
|
|
860
|
+
a model that can be passed to any perturbation-based attribution method.
|
|
861
|
+
|
|
862
|
+
The result is a per-token attribution for each concept, revealing which parts
|
|
863
|
+
of the input are most responsible for activating a given concept.
|
|
864
|
+
|
|
865
|
+
Note:
|
|
866
|
+
Only perturbation-based methods (Lime, KernelShap, Occlusion, Sobol) are
|
|
867
|
+
supported. Gradient-based methods are incompatible because the
|
|
868
|
+
``ModelForInputsToConcepts`` is based on `nnsight`, which makes differentiation complex.
|
|
869
|
+
|
|
870
|
+
Example:
|
|
871
|
+
```python
|
|
872
|
+
from interpreto import Occlusion, SplitSequenceClassification
|
|
873
|
+
from interpreto.concepts import SemiNMFConcepts
|
|
874
|
+
|
|
875
|
+
split_model = SplitSequenceClassification("model_id", device_map="cuda")
|
|
876
|
+
concept_explainer = SemiNMFConcepts(split_model, nb_concepts=20)
|
|
877
|
+
concept_explainer.fit(activations)
|
|
878
|
+
|
|
879
|
+
explainer = Occlusion(concept_explainer.inputs_to_concepts, split_model.tokenizer)
|
|
880
|
+
results = explainer.explain("Some input text.", targets=torch.arange(5))
|
|
881
|
+
```
|
|
882
|
+
"""
|
|
883
|
+
|
|
884
|
+
_associated_inference_wrapper = InputsToConceptsInferenceWrapper
|
|
885
|
+
inference_wrapper: InputsToConceptsInferenceWrapper
|
|
886
|
+
|
|
887
|
+
def process_inputs_to_explain_and_targets( # type: ignore
|
|
888
|
+
self,
|
|
889
|
+
model_inputs: ModelInputs,
|
|
890
|
+
targets: Iterable[int] | None = None,
|
|
891
|
+
) -> tuple[list[TensorMapping], list[Int[torch.Tensor, "t"]]]:
|
|
892
|
+
"""
|
|
893
|
+
Processes the inputs and targets for explanation.
|
|
894
|
+
|
|
895
|
+
Targets are concept indices shared across all inputs; when no targets are given, all concepts are explained.
|
|
896
|
+
|
|
897
|
+
Args:
|
|
898
|
+
model_inputs (ModelInputs):
|
|
899
|
+
The inputs to the model.
|
|
900
|
+
targets (Optional[Iterable[int]]):
|
|
901
|
+
The targets to be explained.
|
|
902
|
+
If None, all concepts are explained.
|
|
903
|
+
|
|
904
|
+
Returns:
|
|
905
|
+
processed_inputs (list[TensorMapping]):
|
|
906
|
+
The processed inputs.
|
|
907
|
+
processed_targets (list[Int[torch.Tensor, "t"]]):
|
|
908
|
+
The processed targets.
|
|
909
|
+
"""
|
|
910
|
+
sanitized_targets: list[Int[torch.Tensor, "t"]]
|
|
911
|
+
if targets is None:
|
|
912
|
+
# explain all concepts
|
|
913
|
+
input_wise_targets = torch.arange(self.inference_wrapper.model.nb_concepts) # type: ignore
|
|
914
|
+
sanitized_targets = [input_wise_targets] * len(model_inputs) # type: ignore
|
|
915
|
+
else:
|
|
916
|
+
# targets are concept indices, shared across all inputs
|
|
917
|
+
if isinstance(targets, torch.Tensor):
|
|
918
|
+
input_wise_targets = targets.long()
|
|
919
|
+
else:
|
|
920
|
+
input_wise_targets = torch.tensor(list(targets), dtype=torch.long)
|
|
921
|
+
sanitized_targets = [input_wise_targets] * len(model_inputs) # type: ignore
|
|
922
|
+
return model_inputs, sanitized_targets # type: ignore
|
|
923
|
+
|
|
924
|
+
def post_processing(self, contribution: Float[torch.Tensor, "t l"]):
|
|
925
|
+
"""
|
|
926
|
+
Concept-specific post-processing of the attribution scores.
|
|
927
|
+
|
|
928
|
+
No post-processing is required for concept attributions.
|
|
929
|
+
|
|
930
|
+
Args:
|
|
931
|
+
contribution (Float[torch.Tensor, "t l"]): The contribution values.
|
|
932
|
+
|
|
933
|
+
Returns:
|
|
934
|
+
model_task (ModelTask): The model task.
|
|
935
|
+
contribution (Float[torch.Tensor, "t l"]): The post-processed contribution values.
|
|
936
|
+
"""
|
|
937
|
+
return ModelTask.CONCEPTS, contribution
|
|
938
|
+
|
|
939
|
+
|
|
849
940
|
class FactoryGeneratedMeta(type):
|
|
850
941
|
"""
|
|
851
942
|
Metaclass to distinguish classes generated by the MultitaskExplainerMixin.
|
|
852
943
|
"""
|
|
853
944
|
|
|
854
945
|
|
|
855
|
-
class MultitaskExplainerMixin
|
|
946
|
+
class MultitaskExplainerMixin:
|
|
856
947
|
"""
|
|
857
948
|
Mixin class to generate the appropriate Explainer based on the model type.
|
|
858
949
|
"""
|
|
@@ -866,6 +957,11 @@ class MultitaskExplainerMixin(AttributionExplainer):
|
|
|
866
957
|
if model.__class__.__name__.endswith("ForCausalLM") or model.__class__.__name__.endswith("LMHeadModel"):
|
|
867
958
|
t = FactoryGeneratedMeta("Generation" + cls.__name__, (cls, GenerationAttributionExplainer), {})
|
|
868
959
|
return t.__new__(t, model, *args, **kwargs) # type: ignore
|
|
960
|
+
if model.__class__.__name__.endswith("ForInputsToConcepts"):
|
|
961
|
+
t = FactoryGeneratedMeta(
|
|
962
|
+
"InputsToConcepts" + cls.__name__, (cls, InputsToConceptsAttributionsExplainer), {}
|
|
963
|
+
)
|
|
964
|
+
return t.__new__(t, model, *args, **kwargs) # type: ignore
|
|
869
965
|
raise NotImplementedError(
|
|
870
966
|
"Model type not supported for Explainer. Use a ModelForSequenceClassification, a ModelForCausalLM model or a LMHeadModel model."
|
|
871
967
|
)
|
{interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/kernel_shap.py
RENAMED
|
@@ -40,6 +40,7 @@ from interpreto.attributions.aggregations.linear_regression_aggregation import (
|
|
|
40
40
|
from interpreto.attributions.base import AttributionExplainer, MultitaskExplainerMixin, setup_token_ids
|
|
41
41
|
from interpreto.attributions.perturbations.shap_perturbation import ShapTokenPerturbator
|
|
42
42
|
from interpreto.commons.granularity import Granularity, GranularityAggregationStrategy
|
|
43
|
+
from interpreto.concepts.base import ModelForInputsToConcepts
|
|
43
44
|
from interpreto.model_wrapping.inference_wrapper import InferenceModes
|
|
44
45
|
|
|
45
46
|
|
|
@@ -68,7 +69,7 @@ class KernelShap(MultitaskExplainerMixin, AttributionExplainer):
|
|
|
68
69
|
|
|
69
70
|
def __init__(
|
|
70
71
|
self,
|
|
71
|
-
model: PreTrainedModel,
|
|
72
|
+
model: PreTrainedModel | ModelForInputsToConcepts,
|
|
72
73
|
tokenizer: PreTrainedTokenizer,
|
|
73
74
|
batch_size: int = 4,
|
|
74
75
|
granularity: Granularity = Granularity.WORD,
|
|
@@ -81,7 +82,7 @@ class KernelShap(MultitaskExplainerMixin, AttributionExplainer):
|
|
|
81
82
|
Initialize the attribution method.
|
|
82
83
|
|
|
83
84
|
Args:
|
|
84
|
-
model (PreTrainedModel): model to explain
|
|
85
|
+
model (PreTrainedModel | ModelForInputsToConcepts): model to explain
|
|
85
86
|
tokenizer (PreTrainedTokenizer): Hugging Face tokenizer associated with the model
|
|
86
87
|
batch_size (int): batch size for the attribution method
|
|
87
88
|
granularity (Granularity, optional): The level of granularity for the explanation.
|
|
@@ -43,6 +43,7 @@ from interpreto.attributions.aggregations.linear_regression_aggregation import (
|
|
|
43
43
|
from interpreto.attributions.base import AttributionExplainer, InferenceModes, MultitaskExplainerMixin, setup_token_ids
|
|
44
44
|
from interpreto.attributions.perturbations.random_perturbation import RandomMaskedTokenPerturbator
|
|
45
45
|
from interpreto.commons import Granularity, GranularityAggregationStrategy
|
|
46
|
+
from interpreto.concepts.base import ModelForInputsToConcepts
|
|
46
47
|
|
|
47
48
|
|
|
48
49
|
class Lime(MultitaskExplainerMixin, AttributionExplainer):
|
|
@@ -72,7 +73,7 @@ class Lime(MultitaskExplainerMixin, AttributionExplainer):
|
|
|
72
73
|
|
|
73
74
|
def __init__(
|
|
74
75
|
self,
|
|
75
|
-
model: PreTrainedModel,
|
|
76
|
+
model: PreTrainedModel | ModelForInputsToConcepts,
|
|
76
77
|
tokenizer: PreTrainedTokenizer,
|
|
77
78
|
batch_size: int = 4,
|
|
78
79
|
granularity: Granularity = Granularity.WORD,
|
|
@@ -88,7 +89,7 @@ class Lime(MultitaskExplainerMixin, AttributionExplainer):
|
|
|
88
89
|
Initialize the attribution method.
|
|
89
90
|
|
|
90
91
|
Args:
|
|
91
|
-
model (PreTrainedModel): model to explain
|
|
92
|
+
model (PreTrainedModel | ModelForInputsToConcepts): model to explain
|
|
92
93
|
tokenizer (PreTrainedTokenizer): Hugging Face tokenizer associated with the model
|
|
93
94
|
batch_size (int): batch size for the attribution method
|
|
94
95
|
granularity (Granularity, optional): The level of granularity for the explanation.
|
{interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/occlusion.py
RENAMED
|
@@ -29,10 +29,9 @@ Occlusion attribution method
|
|
|
29
29
|
from __future__ import annotations
|
|
30
30
|
|
|
31
31
|
from collections.abc import Callable
|
|
32
|
-
from typing import Any
|
|
33
32
|
|
|
34
33
|
import torch
|
|
35
|
-
from transformers import PreTrainedTokenizer
|
|
34
|
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
|
36
35
|
|
|
37
36
|
from interpreto.attributions.aggregations.base import OcclusionAggregator
|
|
38
37
|
from interpreto.attributions.base import (
|
|
@@ -42,6 +41,7 @@ from interpreto.attributions.base import (
|
|
|
42
41
|
)
|
|
43
42
|
from interpreto.attributions.perturbations import OcclusionPerturbator
|
|
44
43
|
from interpreto.commons.granularity import Granularity, GranularityAggregationStrategy
|
|
44
|
+
from interpreto.concepts.base import ModelForInputsToConcepts
|
|
45
45
|
from interpreto.model_wrapping.inference_wrapper import InferenceModes
|
|
46
46
|
|
|
47
47
|
|
|
@@ -68,7 +68,7 @@ class Occlusion(MultitaskExplainerMixin, AttributionExplainer):
|
|
|
68
68
|
|
|
69
69
|
def __init__(
|
|
70
70
|
self,
|
|
71
|
-
model:
|
|
71
|
+
model: PreTrainedModel | ModelForInputsToConcepts,
|
|
72
72
|
tokenizer: PreTrainedTokenizer,
|
|
73
73
|
batch_size: int = 4,
|
|
74
74
|
granularity: Granularity = Granularity.WORD,
|
|
@@ -80,7 +80,7 @@ class Occlusion(MultitaskExplainerMixin, AttributionExplainer):
|
|
|
80
80
|
Initialize the attribution method.
|
|
81
81
|
|
|
82
82
|
Args:
|
|
83
|
-
model (PreTrainedModel): model to explain
|
|
83
|
+
model (PreTrainedModel | ModelForInputsToConcepts): model to explain
|
|
84
84
|
tokenizer (PreTrainedTokenizer): Hugging Face tokenizer associated with the model
|
|
85
85
|
batch_size (int): batch size for the attribution method
|
|
86
86
|
granularity (Granularity, optional): The level of granularity for the explanation.
|
{interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/sobol_attribution.py
RENAMED
|
@@ -41,6 +41,7 @@ from interpreto.attributions.perturbations.sobol_perturbation import (
|
|
|
41
41
|
SobolTokenPerturbator,
|
|
42
42
|
)
|
|
43
43
|
from interpreto.commons.granularity import Granularity, GranularityAggregationStrategy
|
|
44
|
+
from interpreto.concepts.base import ModelForInputsToConcepts
|
|
44
45
|
|
|
45
46
|
|
|
46
47
|
class Sobol(MultitaskExplainerMixin, AttributionExplainer):
|
|
@@ -73,7 +74,7 @@ class Sobol(MultitaskExplainerMixin, AttributionExplainer):
|
|
|
73
74
|
|
|
74
75
|
def __init__(
|
|
75
76
|
self,
|
|
76
|
-
model: PreTrainedModel,
|
|
77
|
+
model: PreTrainedModel | ModelForInputsToConcepts,
|
|
77
78
|
tokenizer: PreTrainedTokenizer,
|
|
78
79
|
batch_size: int = 4,
|
|
79
80
|
granularity: Granularity = Granularity.WORD,
|
|
@@ -88,7 +89,7 @@ class Sobol(MultitaskExplainerMixin, AttributionExplainer):
|
|
|
88
89
|
Initialize the attribution method.
|
|
89
90
|
|
|
90
91
|
Args:
|
|
91
|
-
model (PreTrainedModel): model to explain
|
|
92
|
+
model (PreTrainedModel | ModelForInputsToConcepts): model to explain
|
|
92
93
|
tokenizer (PreTrainedTokenizer): Hugging Face tokenizer associated with the model
|
|
93
94
|
batch_size (int): batch size for the attribution method
|
|
94
95
|
granularity (Granularity, optional): The level of granularity for the explanation.
|
|
@@ -43,6 +43,18 @@ from .methods import (
|
|
|
43
43
|
TopKSAEConcepts,
|
|
44
44
|
VanillaSAEConcepts,
|
|
45
45
|
)
|
|
46
|
+
from .probes import (
|
|
47
|
+
CosineCentroidProbe,
|
|
48
|
+
DiagonalMahalanobisCentroidProbe,
|
|
49
|
+
DotProductCentroidProbe,
|
|
50
|
+
LinearRegressionProbe,
|
|
51
|
+
LinearSVMProbe,
|
|
52
|
+
LogisticRegressionProbe,
|
|
53
|
+
MeansDiffProbe,
|
|
54
|
+
ProbeExplainer,
|
|
55
|
+
SqL2CentroidProbe,
|
|
56
|
+
SVDDCentroidProbe,
|
|
57
|
+
)
|
|
46
58
|
|
|
47
59
|
__all__ = [
|
|
48
60
|
"BatchTopKSAEConcepts",
|
|
@@ -50,19 +62,29 @@ __all__ = [
|
|
|
50
62
|
"ConceptAutoEncoderExplainer",
|
|
51
63
|
"ConceptEncoderExplainer",
|
|
52
64
|
"ConvexNMFConcepts",
|
|
65
|
+
"CosineCentroidProbe",
|
|
66
|
+
"DiagonalMahalanobisCentroidProbe",
|
|
53
67
|
"DictionaryLearningConcepts",
|
|
68
|
+
"DotProductCentroidProbe",
|
|
54
69
|
"ICAConcepts",
|
|
55
70
|
"JumpReLUSAEConcepts",
|
|
56
71
|
"KMeansConcepts",
|
|
57
72
|
"LLMLabels",
|
|
73
|
+
"LinearRegressionProbe",
|
|
74
|
+
"LinearSVMProbe",
|
|
75
|
+
"LogisticRegressionProbe",
|
|
76
|
+
"MeansDiffProbe",
|
|
58
77
|
"MpSAEConcepts",
|
|
59
78
|
"NeuronsAsConcepts",
|
|
60
79
|
"NMFConcepts",
|
|
61
80
|
"PCAConcepts",
|
|
81
|
+
"ProbeExplainer",
|
|
62
82
|
"SAELossClasses",
|
|
63
83
|
"SemiNMFConcepts",
|
|
64
84
|
"SparsePCAConcepts",
|
|
85
|
+
"SqL2CentroidProbe",
|
|
65
86
|
"SVDConcepts",
|
|
87
|
+
"SVDDCentroidProbe",
|
|
66
88
|
"TopKInputs",
|
|
67
89
|
"TopKSAEConcepts",
|
|
68
90
|
"VanillaSAEConcepts",
|