interpreto 0.5.0.dev0__tar.gz → 0.5.0.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (135) hide show
  1. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/PKG-INFO +23 -13
  2. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/README.md +21 -11
  3. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/__init__.py +2 -1
  4. interpreto-0.5.0.dev1/interpreto/_vendor/overcomplete/base.pyi +16 -0
  5. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/base.py +101 -5
  6. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/kernel_shap.py +3 -2
  7. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/lime.py +3 -2
  8. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/occlusion.py +4 -4
  9. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/sobol_attribution.py +3 -2
  10. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/__init__.py +22 -0
  11. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/base.py +116 -43
  12. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/base.py +11 -5
  13. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/llm_labels.py +1 -1
  14. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/topk_inputs.py +1 -1
  15. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/sklearn_wrappers.py +1 -1
  16. interpreto-0.5.0.dev1/interpreto/concepts/probes/__init__.py +81 -0
  17. interpreto-0.5.0.dev1/interpreto/concepts/probes/base.py +279 -0
  18. interpreto-0.5.0.dev1/interpreto/concepts/probes/bias_calibrators.py +223 -0
  19. interpreto-0.5.0.dev1/interpreto/concepts/probes/centroid.py +432 -0
  20. interpreto-0.5.0.dev1/interpreto/concepts/probes/linear.py +418 -0
  21. interpreto-0.5.0.dev1/interpreto/concepts/probes/normalizations.py +201 -0
  22. interpreto-0.5.0.dev1/interpreto/concepts/probes/sklearn.py +140 -0
  23. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/__init__.py +2 -1
  24. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/inference_wrapper.py +2 -0
  25. interpreto-0.5.0.dev1/interpreto/model_wrapping/inputs_to_concepts_inference_wrapper.py +91 -0
  26. interpreto-0.5.0.dev1/interpreto/model_wrapping/split_sequence_classification.py +427 -0
  27. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/typing.py +1 -8
  28. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto.egg-info/PKG-INFO +23 -13
  29. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto.egg-info/SOURCES.txt +10 -1
  30. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto.egg-info/requires.txt +1 -1
  31. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/mkdocs.yml +17 -9
  32. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/pyproject.toml +2 -2
  33. interpreto-0.5.0.dev0/interpreto/concepts/plots/__init__.py +0 -0
  34. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/.editorconfig +0 -0
  35. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/.gitignore +0 -0
  36. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/.pre-commit-config.yaml +0 -0
  37. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/AGENTS.md +0 -0
  38. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/LICENSE +0 -0
  39. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/MANIFEST.in +0 -0
  40. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/VENDORED_FROM +0 -0
  41. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/__init__.py +0 -0
  42. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/_sync_overcomplete.py +0 -0
  43. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/base.py +0 -0
  44. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/data.py +0 -0
  45. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/metrics.py +0 -0
  46. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/__init__.py +0 -0
  47. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/archetypal_analysis.py +0 -0
  48. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/base.py +0 -0
  49. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/convex_nmf.py +0 -0
  50. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/nmf.py +0 -0
  51. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/semi_nmf.py +0 -0
  52. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/sklearn_wrappers.py +0 -0
  53. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/utils.py +0 -0
  54. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/overcomplete.patch +0 -0
  55. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/__init__.py +0 -0
  56. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/archetypal_dictionary.py +0 -0
  57. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/base.py +0 -0
  58. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/batchtopk_sae.py +0 -0
  59. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/dictionary.py +0 -0
  60. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/factory.py +0 -0
  61. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/jump_sae.py +0 -0
  62. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/kernels.py +0 -0
  63. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/losses.py +0 -0
  64. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/modules.py +0 -0
  65. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/mp_sae.py +0 -0
  66. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/omp_sae.py +0 -0
  67. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/optimizer.py +0 -0
  68. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/qsae.py +0 -0
  69. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/rasae.py +0 -0
  70. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/topk_sae.py +0 -0
  71. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/trackers.py +0 -0
  72. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/train.py +0 -0
  73. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/__init__.py +0 -0
  74. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/__init__.py +0 -0
  75. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/base.py +0 -0
  76. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/linear_regression_aggregation.py +0 -0
  77. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/sobol_aggregation.py +0 -0
  78. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/__init__.py +0 -0
  79. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/gradient_shap.py +0 -0
  80. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/integrated_gradients.py +0 -0
  81. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/saliency.py +0 -0
  82. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/smooth_grad.py +0 -0
  83. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/square_grad.py +0 -0
  84. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/var_grad.py +0 -0
  85. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/metrics/__init__.py +0 -0
  86. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/metrics/insertion_deletion.py +0 -0
  87. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/__init__.py +0 -0
  88. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/base.py +0 -0
  89. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/gaussian_noise_perturbation.py +0 -0
  90. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/gradient_shap_perturbation.py +0 -0
  91. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/insertion_deletion_perturbation.py +0 -0
  92. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/linear_interpolation_perturbation.py +0 -0
  93. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/occlusion_perturbation.py +0 -0
  94. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/random_perturbation.py +0 -0
  95. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/shap_perturbation.py +0 -0
  96. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/sobol_perturbation.py +0 -0
  97. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/attributions/plots/__init__.py +0 -0
  98. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/commons/__init__.py +0 -0
  99. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/commons/distances.py +0 -0
  100. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/commons/generator_tools.py +0 -0
  101. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/commons/granularity.py +0 -0
  102. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/__init__.py +0 -0
  103. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/__init__.py +0 -0
  104. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/cockatiel.py +0 -0
  105. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/neurons_as_concepts.py +0 -0
  106. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/overcomplete.py +0 -0
  107. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/__init__.py +0 -0
  108. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/consim.py +0 -0
  109. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/dictionary_metrics.py +0 -0
  110. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/reconstruction_metrics.py +0 -0
  111. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/sparsity_metrics.py +0 -0
  112. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/classification_inference_wrapper.py +0 -0
  113. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/generation_inference_wrapper.py +0 -0
  114. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/llm_interface.py +0 -0
  115. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/model_with_split_points.py +0 -0
  116. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/splitting_utils.py +0 -0
  117. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/transformers_classes.py +0 -0
  118. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/README.md +0 -0
  119. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/__init__.py +0 -0
  120. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/attributions.py +0 -0
  121. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/commons.py +0 -0
  122. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/concepts.py +0 -0
  123. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/css/visualization.css +0 -0
  124. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/dom_renderer.js +0 -0
  125. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/state_manager.js +0 -0
  126. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/style_computer.js +0 -0
  127. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/view_updater.js +0 -0
  128. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/attribution_classification.js +0 -0
  129. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/attribution_generation.js +0 -0
  130. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/concepts_classification_global.js +0 -0
  131. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/concepts_classification_local.js +0 -0
  132. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/concepts_generation_local.js +0 -0
  133. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto.egg-info/dependency_links.txt +0 -0
  134. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/interpreto.egg-info/top_level.txt +0 -0
  135. {interpreto-0.5.0.dev0 → interpreto-0.5.0.dev1}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: interpreto
3
- Version: 0.5.0.dev0
3
+ Version: 0.5.0.dev1
4
4
  Summary: Interpretability toolbox for LLMs
5
5
  Author: FOR Team
6
6
  Author-email: fanny.jourdan@irt-saintexupery.com
@@ -57,7 +57,7 @@ License-File: LICENSE
57
57
  Requires-Dist: transformers>=4.22.0
58
58
  Requires-Dist: nltk
59
59
  Requires-Dist: torch>=2.0
60
- Requires-Dist: nnsight<0.6.0,>=0.5.1
60
+ Requires-Dist: nnsight<0.8.0,>=0.7.0
61
61
  Requires-Dist: jaxtyping<=0.2.36
62
62
  Requires-Dist: beartype
63
63
  Requires-Dist: mknotebooks
@@ -120,6 +120,7 @@ Dynamic: license-file
120
120
  <p align="center">
121
121
  <a href="https://for-sight-ai.github.io/interpreto/"><strong>📚 Explore Interpreto docs &gt;&gt;</strong></a><br />
122
122
  <a href="https://for-sight-ai.github.io/interpreto-demo/"><strong>🖼️ Checkout our explanation gallery &gt;&gt;</strong></a>
123
+ <a href="https://arxiv.org/abs/2512.09730"><strong>📜 Read our paper &gt;&gt;</strong></a>
123
124
  </p>
124
125
 
125
126
  ## 🚀 Quick Start
@@ -163,41 +164,51 @@ They all work seamlessly for both classification (`...ForSequenceClassification`
163
164
 
164
165
  Concept-based explanations aim to provide high-level interpretations of latent model representations.
165
166
 
166
- Interpreto generalizes these methods through four core steps:
167
+ We propose both supervised (probes and CAVs) and unsupervised (dictionary learning) approaches.
168
+
169
+ Interpreto generalizes these methods through four core steps; the first two are common to both approaches:
167
170
 
168
171
  1. Split a model in two and obtain a dataset of activations
169
- 2. Concept Discovery (e.g., from latent embeddings)
170
- 3. Concept Interpretation (mapping discovered concepts to human-understandable elements)
171
- 4. Concept-to-Output Attribution (assessing concept relevance to model outputs)
172
+ 2. Learn concepts (e.g., from latent embeddings)
173
+ 3. Interpret concepts (mapping discovered concepts to human-understandable elements)
174
+ 4. Estimate concept importance (assessing concept relevance to model outputs)
172
175
 
173
176
  **1. Split a model in two and obtain a dataset of activations:** (mainly via [`nnsight`](https://github.com/ndif-team/nnsight)):
174
177
 
175
178
  Choose any layer in any HuggingFace language model with our `ModelWithSplitPoints` based on `nnsight`. Then pass a dataset through it to obtain a dataset of activations.
176
179
 
177
- **2. Dictionary Learning for Concept Discovery** (mainly via [`overcomplete`](https://github.com/KempnerInstitute/overcomplete)):
180
+ **2. (supervised) Train probe** with the [`ProbeExplainer`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/)
181
+
182
+ We differentiate two families of probes:
183
+
184
+ - Linear probes: [`LinearRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearRegressionProbe), [`LogisticRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LogisticRegressionProbe), [`LinearSVMProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearSVMProbe), [`MeansDiffProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.MeansDiffProbe)
185
+ - Centroid-based probes: [`CosineCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.CosineCentroidProbe), [`DotProductCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DotProductCentroidProbe), [`SqL2CentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SqL2CentroidProbe), [`SVDDCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SVDDCentroidProbe), [`DiagonalMahalanobisCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DiagonalMahalanobisCentroidProbe)
186
+
187
+ Both can be tuned with `bias_calibrator` and `normalization` parameters.
188
+
189
+ **2. (unsupervised) Dictionary Learning for Concept Discovery** (mainly via [`overcomplete`](https://github.com/KempnerInstitute/overcomplete)):
178
190
 
179
191
  - Interpret neurons directly via [`NeuronsAsConcepts`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/neurons_as_concepts/)
180
192
  - [`NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.NMFConcepts), [`Semi-NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SemiNMFConcepts), [`ConvexNMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ConvexNMFConcepts)
181
193
  - [`ICA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ICAConcepts), [`SVD`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SVDConcepts), [`PCA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.PCAConcepts), [`KMeans`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.KMeansConcepts)
182
194
  - SAE variants: [`Vanilla SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.VanillaSAEConcepts), [`TopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.TopKSAEConcepts), [`JumpReLU SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.JumpReLUSAEConcepts), [`BatchTopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.BatchTopKSAEConcepts)
183
195
 
184
- **3. Available Concept Interpretation Techniques:**
196
+ **3. (unsupervised) Available Concept Interpretation Techniques:**
185
197
 
186
198
  - Top-k tokens from tokenizer vocabulary via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs) and `use_vocab=True`
187
199
  - Top-k tokens/words/sentences/samples from specific datasets via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs)
188
200
  - Label concepts via LLMs with [`LLMLabels`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.LLMLabels) ([Bills et al. 2023](https://openai.com/index/language-models-can-explain-neurons-in-language-models/))
201
+ - Input-to-concept attribution from dataset examples ([Concept Attributions](https://for-sight-ai.github.io/interpreto/api/concepts/interpretations/concept_attributions/)) ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
189
202
 
190
203
  <details><summary>Concept Interpretation Techniques Added in the future:</summary>
191
204
 
192
- - Input-to-concept attribution from dataset examples ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
193
- - Theme prediction via LLMs from top-k tokens/sentences
194
205
  - Aligning concepts with human labels ([Sajjad et al. 2022](https://aclanthology.org/2022.naacl-main.225/))
195
206
  - Word cloud visualizations of concepts ([Dalvi et al. 2022](https://arxiv.org/abs/2205.07237))
196
207
  - VocabProj & TokenChange ([Gur-Arieh et al. 2025](https://arxiv.org/abs/2501.08319))
197
208
 
198
209
  </details>
199
210
 
200
- **4. Concept-to-Output Attribution:**
211
+ **4. (unsupervised) Concept-to-Output Attribution:**
201
212
 
202
213
  Estimate the contribution of each concept to the model output.
203
214
 
@@ -207,7 +218,6 @@ Can be obtained with any concept-based explainer via [`MethodConcepts.concept_ou
207
218
 
208
219
  Thanks to this generalization encompassing all concept-based methods and our highly flexible architecture, we can easily obtain a large number of concept-based methods:
209
220
 
210
- - CAV and TCAV: [Kim et al. 2018, Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV)](http://proceedings.mlr.press/v80/kim18d.html)
211
221
  - ConceptSHAP: [Yeh et al. 2020, On Completeness-aware Concept-Based Explanations in Deep Neural Networks](https://proceedings.neurips.cc/paper/2020/hash/ecb287ff763c169694f682af52c1f309-Abstract.html)
212
222
  - COCKATIEL: [Jourdan et al. 2023, COCKATIEL: COntinuous Concept ranKed ATtribution with Interpretable ELements for explaining neural net classifiers on NLP](https://aclanthology.org/2023.findings-acl.317/)
213
223
  - Yun et al. 2021, [Transformer visualization via dictionary learning: contextualized embedding as a linear superposition of transformer factors](https://arxiv.org/abs/2103.15949)
@@ -264,7 +274,7 @@ Interpreto 🪄 is a project of the [FOR](https://www.irt-saintexupery.com/fr/fo
264
274
 
265
275
  ## 🗞️ Citation
266
276
 
267
- If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ our paper:
277
+ If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ [our paper](https://arxiv.org/abs/2512.09730):
268
278
 
269
279
  ```bibtex
270
280
  @article{poche2025interpreto,
@@ -15,6 +15,7 @@
15
15
  <p align="center">
16
16
  <a href="https://for-sight-ai.github.io/interpreto/"><strong>📚 Explore Interpreto docs &gt;&gt;</strong></a><br />
17
17
  <a href="https://for-sight-ai.github.io/interpreto-demo/"><strong>🖼️ Checkout our explanation gallery &gt;&gt;</strong></a>
18
+ <a href="https://arxiv.org/abs/2512.09730"><strong>📜 Read our paper &gt;&gt;</strong></a>
18
19
  </p>
19
20
 
20
21
  ## 🚀 Quick Start
@@ -58,41 +59,51 @@ They all work seamlessly for both classification (`...ForSequenceClassification`
58
59
 
59
60
  Concept-based explanations aim to provide high-level interpretations of latent model representations.
60
61
 
61
- Interpreto generalizes these methods through four core steps:
62
+ We propose both supervised (probes and CAVs) and unsupervised (dictionary learning) approaches.
63
+
64
+ Interpreto generalizes these methods through four core steps; the first two are common to both approaches:
62
65
 
63
66
  1. Split a model in two and obtain a dataset of activations
64
- 2. Concept Discovery (e.g., from latent embeddings)
65
- 3. Concept Interpretation (mapping discovered concepts to human-understandable elements)
66
- 4. Concept-to-Output Attribution (assessing concept relevance to model outputs)
67
+ 2. Learn concepts (e.g., from latent embeddings)
68
+ 3. Interpret concepts (mapping discovered concepts to human-understandable elements)
69
+ 4. Estimate concept importance (assessing concept relevance to model outputs)
67
70
 
68
71
  **1. Split a model in two and obtain a dataset of activations:** (mainly via [`nnsight`](https://github.com/ndif-team/nnsight)):
69
72
 
70
73
  Choose any layer in any HuggingFace language model with our `ModelWithSplitPoints` based on `nnsight`. Then pass a dataset through it to obtain a dataset of activations.
71
74
 
72
- **2. Dictionary Learning for Concept Discovery** (mainly via [`overcomplete`](https://github.com/KempnerInstitute/overcomplete)):
75
+ **2. (supervised) Train probe** with the [`ProbeExplainer`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/)
76
+
77
+ We differentiate two families of probes:
78
+
79
+ - Linear probes: [`LinearRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearRegressionProbe), [`LogisticRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LogisticRegressionProbe), [`LinearSVMProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearSVMProbe), [`MeansDiffProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.MeansDiffProbe)
80
+ - Centroid-based probes: [`CosineCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.CosineCentroidProbe), [`DotProductCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DotProductCentroidProbe), [`SqL2CentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SqL2CentroidProbe), [`SVDDCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SVDDCentroidProbe), [`DiagonalMahalanobisCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DiagonalMahalanobisCentroidProbe)
81
+
82
+ Both can be tuned with `bias_calibrator` and `normalization` parameters.
83
+
84
+ **2. (unsupervised) Dictionary Learning for Concept Discovery** (mainly via [`overcomplete`](https://github.com/KempnerInstitute/overcomplete)):
73
85
 
74
86
  - Interpret neurons directly via [`NeuronsAsConcepts`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/neurons_as_concepts/)
75
87
  - [`NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.NMFConcepts), [`Semi-NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SemiNMFConcepts), [`ConvexNMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ConvexNMFConcepts)
76
88
  - [`ICA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ICAConcepts), [`SVD`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SVDConcepts), [`PCA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.PCAConcepts), [`KMeans`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.KMeansConcepts)
77
89
  - SAE variants: [`Vanilla SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.VanillaSAEConcepts), [`TopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.TopKSAEConcepts), [`JumpReLU SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.JumpReLUSAEConcepts), [`BatchTopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.BatchTopKSAEConcepts)
78
90
 
79
- **3. Available Concept Interpretation Techniques:**
91
+ **3. (unsupervised) Available Concept Interpretation Techniques:**
80
92
 
81
93
  - Top-k tokens from tokenizer vocabulary via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs) and `use_vocab=True`
82
94
  - Top-k tokens/words/sentences/samples from specific datasets via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs)
83
95
  - Label concepts via LLMs with [`LLMLabels`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.LLMLabels) ([Bills et al. 2023](https://openai.com/index/language-models-can-explain-neurons-in-language-models/))
96
+ - Input-to-concept attribution from dataset examples ([Concept Attributions](https://for-sight-ai.github.io/interpreto/api/concepts/interpretations/concept_attributions/)) ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
84
97
 
85
98
  <details><summary>Concept Interpretation Techniques Added in the future:</summary>
86
99
 
87
- - Input-to-concept attribution from dataset examples ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
88
- - Theme prediction via LLMs from top-k tokens/sentences
89
100
  - Aligning concepts with human labels ([Sajjad et al. 2022](https://aclanthology.org/2022.naacl-main.225/))
90
101
  - Word cloud visualizations of concepts ([Dalvi et al. 2022](https://arxiv.org/abs/2205.07237))
91
102
  - VocabProj & TokenChange ([Gur-Arieh et al. 2025](https://arxiv.org/abs/2501.08319))
92
103
 
93
104
  </details>
94
105
 
95
- **4. Concept-to-Output Attribution:**
106
+ **4. (unsupervised) Concept-to-Output Attribution:**
96
107
 
97
108
  Estimate the contribution of each concept to the model output.
98
109
 
@@ -102,7 +113,6 @@ Can be obtained with any concept-based explainer via [`MethodConcepts.concept_ou
102
113
 
103
114
  Thanks to this generalization encompassing all concept-based methods and our highly flexible architecture, we can easily obtain a large number of concept-based methods:
104
115
 
105
- - CAV and TCAV: [Kim et al. 2018, Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV)](http://proceedings.mlr.press/v80/kim18d.html)
106
116
  - ConceptSHAP: [Yeh et al. 2020, On Completeness-aware Concept-Based Explanations in Deep Neural Networks](https://proceedings.neurips.cc/paper/2020/hash/ecb287ff763c169694f682af52c1f309-Abstract.html)
107
117
  - COCKATIEL: [Jourdan et al. 2023, COCKATIEL: COntinuous Concept ranKed ATtribution with Interpretable ELements for explaining neural net classifiers on NLP](https://aclanthology.org/2023.findings-acl.317/)
108
118
  - Yun et al. 2021, [Transformer visualization via dictionary learning: contextualized embedding as a linear superposition of transformer factors](https://arxiv.org/abs/2103.15949)
@@ -159,7 +169,7 @@ Interpreto 🪄 is a project of the [FOR](https://www.irt-saintexupery.com/fr/fo
159
169
 
160
170
  ## 🗞️ Citation
161
171
 
162
- If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ our paper:
172
+ If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ [our paper](https://arxiv.org/abs/2512.09730):
163
173
 
164
174
  ```bibtex
165
175
  @article{poche2025interpreto,
@@ -42,7 +42,7 @@ from .attributions import (
42
42
  from .commons import (
43
43
  Granularity,
44
44
  )
45
- from .model_wrapping import ModelWithSplitPoints
45
+ from .model_wrapping import ModelWithSplitPoints, SplitSequenceClassification
46
46
  from .visualizations import (
47
47
  AttributionVisualization,
48
48
  plot_attributions,
@@ -71,6 +71,7 @@ __all__ = [
71
71
  "SquareGrad",
72
72
  "Saliency",
73
73
  "SmoothGrad",
74
+ "SplitSequenceClassification",
74
75
  "Sobol",
75
76
  "VarGrad",
76
77
  "get_version",
@@ -0,0 +1,16 @@
1
+ from abc import ABC
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ class BaseDictionaryLearning(ABC, nn.Module):
7
+ nb_concepts: int
8
+ device: str | torch.device
9
+ fitted: bool
10
+
11
+ def __init__(self, nb_concepts: int, device: str | torch.device = "cpu") -> None: ...
12
+ def encode(self, x: torch.Tensor) -> torch.Tensor: ...
13
+ def decode(self, z: torch.Tensor) -> torch.Tensor: ...
14
+ def fit(self, x: torch.Tensor) -> None: ...
15
+ def get_dictionary(self) -> torch.Tensor: ...
16
+ def to(self, device: str | torch.device) -> "BaseDictionaryLearning": ... # type: ignore[override]
@@ -45,13 +45,17 @@ from interpreto.attributions.perturbations.base import Perturbator
45
45
  from interpreto.commons import Granularity
46
46
  from interpreto.commons.generator_tools import split_iterator
47
47
  from interpreto.commons.granularity import GranularityAggregationStrategy
48
+ from interpreto.concepts.base import ModelForInputsToConcepts
48
49
  from interpreto.model_wrapping.classification_inference_wrapper import ClassificationInferenceWrapper
49
50
  from interpreto.model_wrapping.generation_inference_wrapper import GenerationInferenceWrapper
50
51
  from interpreto.model_wrapping.inference_wrapper import InferenceModes, InferenceWrapper
52
+ from interpreto.model_wrapping.inputs_to_concepts_inference_wrapper import InputsToConceptsInferenceWrapper
51
53
  from interpreto.typing import ClassificationTarget, GeneratedTarget, ModelInputs, SingleAttribution, TensorMapping
52
54
 
53
55
 
54
- def setup_token_ids(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, require_mask_token: bool = True) -> int:
56
+ def setup_token_ids(
57
+ model: PreTrainedModel | ModelForInputsToConcepts, tokenizer: PreTrainedTokenizer, require_mask_token: bool = True
58
+ ) -> int:
55
59
  """
56
60
  Set up the tokenizer and the model with the appropriate token IDs for padding and masking.
57
61
 
@@ -139,6 +143,7 @@ class ModelTask(Enum):
139
143
 
140
144
  CLASSIFICATION = "classification"
141
145
  GENERATION = "generation"
146
+ CONCEPTS = "concepts"
142
147
 
143
148
 
144
149
  def clone_tensor_mapping(tm: TensorMapping, detach: bool = False) -> TensorMapping:
@@ -232,7 +237,7 @@ class AttributionExplainer:
232
237
 
233
238
  def __init__(
234
239
  self,
235
- model: PreTrainedModel,
240
+ model: PreTrainedModel | ModelForInputsToConcepts,
236
241
  tokenizer: PreTrainedTokenizer,
237
242
  batch_size: int = 4,
238
243
  perturbator: Perturbator | None = None,
@@ -248,7 +253,7 @@ class AttributionExplainer:
248
253
  Initializes the AttributionExplainer.
249
254
 
250
255
  Args:
251
- model (PreTrainedModel): The model to be explained.
256
+ model (PreTrainedModel | ModelForInputsToConcepts): The model to be explained.
252
257
  tokenizer (PreTrainedTokenizer): The tokenizer associated with the model.
253
258
  batch_size (int): The batch size used for model inference.
254
259
  perturbator (Perturbator, optional): Instance used to generate input perturbations.
@@ -276,7 +281,7 @@ class AttributionExplainer:
276
281
  self.tokenizer = tokenizer
277
282
 
278
283
  self.inference_wrapper = self._associated_inference_wrapper(
279
- model,
284
+ model, # type: ignore
280
285
  gradients=use_gradient,
281
286
  input_x_gradient=input_x_gradient,
282
287
  batch_size=batch_size,
@@ -846,13 +851,99 @@ class GenerationAttributionExplainer(AttributionExplainer):
846
851
  return ModelTask.GENERATION, contribution
847
852
 
848
853
 
854
+ class InputsToConceptsAttributionsExplainer(AttributionExplainer):
855
+ """Attribution explainer for input-to-concept models.
856
+
857
+ This explainer computes how much each input token contributes to each concept
858
+ activation. It bridges the attribution framework with the concept framework:
859
+ once a concept explainer is fitted, its ``inputs_to_concepts`` property returns
860
+ a model that can be passed to any perturbation-based attribution method.
861
+
862
+ The result is a per-token attribution for each concept, revealing which parts
863
+ of the input are most responsible for activating a given concept.
864
+
865
+ Note:
866
+ Only perturbation-based methods (Lime, KernelShap, Occlusion, Sobol) are
867
+ supported. Gradient-based methods are incompatible because the
868
+ ``ModelForInputsToConcepts`` is based on `nnsight`, which makes differentiation complex.
869
+
870
+ Example:
871
+ ```python
872
+ from interpreto import Occlusion, SplitSequenceClassification
873
+ from interpreto.concepts import SemiNMFConcepts
874
+
875
+ split_model = SplitSequenceClassification("model_id", device_map="cuda")
876
+ concept_explainer = SemiNMFConcepts(split_model, nb_concepts=20)
877
+ concept_explainer.fit(activations)
878
+
879
+ explainer = Occlusion(concept_explainer.inputs_to_concepts, split_model.tokenizer)
880
+ results = explainer.explain("Some input text.", targets=torch.arange(5))
881
+ ```
882
+ """
883
+
884
+ _associated_inference_wrapper = InputsToConceptsInferenceWrapper
885
+ inference_wrapper: InputsToConceptsInferenceWrapper
886
+
887
+ def process_inputs_to_explain_and_targets( # type: ignore
888
+ self,
889
+ model_inputs: ModelInputs,
890
+ targets: Iterable[int] | None = None,
891
+ ) -> tuple[list[TensorMapping], list[Int[torch.Tensor, "t"]]]:
892
+ """
893
+ Processes the inputs and targets for explanation.
894
+
895
+ Targets are concept indices and are shared across all inputs.
896
+
897
+ Args:
898
+ model_inputs (ModelInputs):
899
+ The inputs to the model.
900
+ targets (Optional[Iterable[int]]):
901
+ The targets to be explained.
902
+ If None, all concepts are explained.
903
+
904
+ Returns:
905
+ processed_inputs (list[TensorMapping]):
906
+ The processed inputs.
907
+ processed_targets (list[Int[torch.Tensor, "t"]]):
908
+ The processed targets.
909
+ """
910
+ sanitized_targets: list[Int[torch.Tensor, "t"]]
911
+ if targets is None:
912
+ # explain all concepts
913
+ input_wise_targets = torch.arange(self.inference_wrapper.model.nb_concepts) # type: ignore
914
+ sanitized_targets = [input_wise_targets] * len(model_inputs) # type: ignore
915
+ else:
916
+ # targets are concept indices, shared across all inputs
917
+ if isinstance(targets, torch.Tensor):
918
+ input_wise_targets = targets.long()
919
+ else:
920
+ input_wise_targets = torch.tensor(list(targets), dtype=torch.long)
921
+ sanitized_targets = [input_wise_targets] * len(model_inputs) # type: ignore
922
+ return model_inputs, sanitized_targets # type: ignore
923
+
924
+ def post_processing(self, contribution: Float[torch.Tensor, "t l"]):
925
+ """
926
+ Concept-specific post-processing of the attribution scores.
927
+
928
+ No post-processing is required for concept attributions.
929
+
930
+ Args:
931
+ contribution (Float[torch.Tensor, "t l"]): The contribution values.
932
+
933
+ Returns:
934
+ model_task (ModelTask): The model task.
935
+ contribution (Float[torch.Tensor, "t l"]): The post-processed contribution values.
936
+ """
937
+ return ModelTask.CONCEPTS, contribution
938
+
939
+
849
940
  class FactoryGeneratedMeta(type):
850
941
  """
851
942
  Metaclass to distinguish classes generated by the MultitaskExplainerMixin.
852
943
  """
853
944
 
854
945
 
855
- class MultitaskExplainerMixin(AttributionExplainer):
946
+ class MultitaskExplainerMixin:
856
947
  """
857
948
  Mixin class to generate the appropriate Explainer based on the model type.
858
949
  """
@@ -866,6 +957,11 @@ class MultitaskExplainerMixin(AttributionExplainer):
866
957
  if model.__class__.__name__.endswith("ForCausalLM") or model.__class__.__name__.endswith("LMHeadModel"):
867
958
  t = FactoryGeneratedMeta("Generation" + cls.__name__, (cls, GenerationAttributionExplainer), {})
868
959
  return t.__new__(t, model, *args, **kwargs) # type: ignore
960
+ if model.__class__.__name__.endswith("ForInputsToConcepts"):
961
+ t = FactoryGeneratedMeta(
962
+ "InputsToConcepts" + cls.__name__, (cls, InputsToConceptsAttributionsExplainer), {}
963
+ )
964
+ return t.__new__(t, model, *args, **kwargs) # type: ignore
869
965
  raise NotImplementedError(
870
966
  "Model type not supported for Explainer. Use a ModelForSequenceClassification, a ModelForCausalLM model or a LMHeadModel model."
871
967
  )
@@ -40,6 +40,7 @@ from interpreto.attributions.aggregations.linear_regression_aggregation import (
40
40
  from interpreto.attributions.base import AttributionExplainer, MultitaskExplainerMixin, setup_token_ids
41
41
  from interpreto.attributions.perturbations.shap_perturbation import ShapTokenPerturbator
42
42
  from interpreto.commons.granularity import Granularity, GranularityAggregationStrategy
43
+ from interpreto.concepts.base import ModelForInputsToConcepts
43
44
  from interpreto.model_wrapping.inference_wrapper import InferenceModes
44
45
 
45
46
 
@@ -68,7 +69,7 @@ class KernelShap(MultitaskExplainerMixin, AttributionExplainer):
68
69
 
69
70
  def __init__(
70
71
  self,
71
- model: PreTrainedModel,
72
+ model: PreTrainedModel | ModelForInputsToConcepts,
72
73
  tokenizer: PreTrainedTokenizer,
73
74
  batch_size: int = 4,
74
75
  granularity: Granularity = Granularity.WORD,
@@ -81,7 +82,7 @@ class KernelShap(MultitaskExplainerMixin, AttributionExplainer):
81
82
  Initialize the attribution method.
82
83
 
83
84
  Args:
84
- model (PreTrainedModel): model to explain
85
+ model (PreTrainedModel | ModelForInputsToConcepts): model to explain
85
86
  tokenizer (PreTrainedTokenizer): Hugging Face tokenizer associated with the model
86
87
  batch_size (int): batch size for the attribution method
87
88
  granularity (Granularity, optional): The level of granularity for the explanation.
@@ -43,6 +43,7 @@ from interpreto.attributions.aggregations.linear_regression_aggregation import (
43
43
  from interpreto.attributions.base import AttributionExplainer, InferenceModes, MultitaskExplainerMixin, setup_token_ids
44
44
  from interpreto.attributions.perturbations.random_perturbation import RandomMaskedTokenPerturbator
45
45
  from interpreto.commons import Granularity, GranularityAggregationStrategy
46
+ from interpreto.concepts.base import ModelForInputsToConcepts
46
47
 
47
48
 
48
49
  class Lime(MultitaskExplainerMixin, AttributionExplainer):
@@ -72,7 +73,7 @@ class Lime(MultitaskExplainerMixin, AttributionExplainer):
72
73
 
73
74
  def __init__(
74
75
  self,
75
- model: PreTrainedModel,
76
+ model: PreTrainedModel | ModelForInputsToConcepts,
76
77
  tokenizer: PreTrainedTokenizer,
77
78
  batch_size: int = 4,
78
79
  granularity: Granularity = Granularity.WORD,
@@ -88,7 +89,7 @@ class Lime(MultitaskExplainerMixin, AttributionExplainer):
88
89
  Initialize the attribution method.
89
90
 
90
91
  Args:
91
- model (PreTrainedModel): model to explain
92
+ model (PreTrainedModel | ModelForInputsToConcepts): model to explain
92
93
  tokenizer (PreTrainedTokenizer): Hugging Face tokenizer associated with the model
93
94
  batch_size (int): batch size for the attribution method
94
95
  granularity (Granularity, optional): The level of granularity for the explanation.
@@ -29,10 +29,9 @@ Occlusion attribution method
29
29
  from __future__ import annotations
30
30
 
31
31
  from collections.abc import Callable
32
- from typing import Any
33
32
 
34
33
  import torch
35
- from transformers import PreTrainedTokenizer
34
+ from transformers import PreTrainedModel, PreTrainedTokenizer
36
35
 
37
36
  from interpreto.attributions.aggregations.base import OcclusionAggregator
38
37
  from interpreto.attributions.base import (
@@ -42,6 +41,7 @@ from interpreto.attributions.base import (
42
41
  )
43
42
  from interpreto.attributions.perturbations import OcclusionPerturbator
44
43
  from interpreto.commons.granularity import Granularity, GranularityAggregationStrategy
44
+ from interpreto.concepts.base import ModelForInputsToConcepts
45
45
  from interpreto.model_wrapping.inference_wrapper import InferenceModes
46
46
 
47
47
 
@@ -68,7 +68,7 @@ class Occlusion(MultitaskExplainerMixin, AttributionExplainer):
68
68
 
69
69
  def __init__(
70
70
  self,
71
- model: Any,
71
+ model: PreTrainedModel | ModelForInputsToConcepts,
72
72
  tokenizer: PreTrainedTokenizer,
73
73
  batch_size: int = 4,
74
74
  granularity: Granularity = Granularity.WORD,
@@ -80,7 +80,7 @@ class Occlusion(MultitaskExplainerMixin, AttributionExplainer):
80
80
  Initialize the attribution method.
81
81
 
82
82
  Args:
83
- model (PreTrainedModel): model to explain
83
+ model (PreTrainedModel | ModelForInputsToConcepts): model to explain
84
84
  tokenizer (PreTrainedTokenizer): Hugging Face tokenizer associated with the model
85
85
  batch_size (int): batch size for the attribution method
86
86
  granularity (Granularity, optional): The level of granularity for the explanation.
@@ -41,6 +41,7 @@ from interpreto.attributions.perturbations.sobol_perturbation import (
41
41
  SobolTokenPerturbator,
42
42
  )
43
43
  from interpreto.commons.granularity import Granularity, GranularityAggregationStrategy
44
+ from interpreto.concepts.base import ModelForInputsToConcepts
44
45
 
45
46
 
46
47
  class Sobol(MultitaskExplainerMixin, AttributionExplainer):
@@ -73,7 +74,7 @@ class Sobol(MultitaskExplainerMixin, AttributionExplainer):
73
74
 
74
75
  def __init__(
75
76
  self,
76
- model: PreTrainedModel,
77
+ model: PreTrainedModel | ModelForInputsToConcepts,
77
78
  tokenizer: PreTrainedTokenizer,
78
79
  batch_size: int = 4,
79
80
  granularity: Granularity = Granularity.WORD,
@@ -88,7 +89,7 @@ class Sobol(MultitaskExplainerMixin, AttributionExplainer):
88
89
  Initialize the attribution method.
89
90
 
90
91
  Args:
91
- model (PreTrainedModel): model to explain
92
+ model (PreTrainedModel | ModelForInputsToConcepts): model to explain
92
93
  tokenizer (PreTrainedTokenizer): Hugging Face tokenizer associated with the model
93
94
  batch_size (int): batch size for the attribution method
94
95
  granularity (Granularity, optional): The level of granularity for the explanation.
@@ -43,6 +43,18 @@ from .methods import (
43
43
  TopKSAEConcepts,
44
44
  VanillaSAEConcepts,
45
45
  )
46
+ from .probes import (
47
+ CosineCentroidProbe,
48
+ DiagonalMahalanobisCentroidProbe,
49
+ DotProductCentroidProbe,
50
+ LinearRegressionProbe,
51
+ LinearSVMProbe,
52
+ LogisticRegressionProbe,
53
+ MeansDiffProbe,
54
+ ProbeExplainer,
55
+ SqL2CentroidProbe,
56
+ SVDDCentroidProbe,
57
+ )
46
58
 
47
59
  __all__ = [
48
60
  "BatchTopKSAEConcepts",
@@ -50,19 +62,29 @@ __all__ = [
50
62
  "ConceptAutoEncoderExplainer",
51
63
  "ConceptEncoderExplainer",
52
64
  "ConvexNMFConcepts",
65
+ "CosineCentroidProbe",
66
+ "DiagonalMahalanobisCentroidProbe",
53
67
  "DictionaryLearningConcepts",
68
+ "DotProductCentroidProbe",
54
69
  "ICAConcepts",
55
70
  "JumpReLUSAEConcepts",
56
71
  "KMeansConcepts",
57
72
  "LLMLabels",
73
+ "LinearRegressionProbe",
74
+ "LinearSVMProbe",
75
+ "LogisticRegressionProbe",
76
+ "MeansDiffProbe",
58
77
  "MpSAEConcepts",
59
78
  "NeuronsAsConcepts",
60
79
  "NMFConcepts",
61
80
  "PCAConcepts",
81
+ "ProbeExplainer",
62
82
  "SAELossClasses",
63
83
  "SemiNMFConcepts",
64
84
  "SparsePCAConcepts",
85
+ "SqL2CentroidProbe",
65
86
  "SVDConcepts",
87
+ "SVDDCentroidProbe",
66
88
  "TopKInputs",
67
89
  "TopKSAEConcepts",
68
90
  "VanillaSAEConcepts",