interpreto 0.4.20__tar.gz → 0.5.0.dev1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138) hide show
  1. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/.pre-commit-config.yaml +1 -1
  2. interpreto-0.5.0.dev1/AGENTS.md +214 -0
  3. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/PKG-INFO +23 -13
  4. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/README.md +21 -11
  5. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/__init__.py +2 -1
  6. interpreto-0.5.0.dev1/interpreto/_vendor/overcomplete/base.pyi +16 -0
  7. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/base.py +3 -4
  8. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/base.py +357 -180
  9. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/kernel_shap.py +7 -6
  10. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/lime.py +7 -6
  11. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/occlusion.py +8 -7
  12. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/saliency.py +2 -1
  13. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/sobol_attribution.py +8 -7
  14. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/metrics/insertion_deletion.py +27 -20
  15. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/__init__.py +16 -1
  16. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/base.py +55 -93
  17. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/gaussian_noise_perturbation.py +17 -16
  18. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/gradient_shap_perturbation.py +4 -5
  19. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/insertion_deletion_perturbation.py +6 -18
  20. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/linear_interpolation_perturbation.py +22 -19
  21. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/occlusion_perturbation.py +0 -1
  22. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/__init__.py +22 -0
  23. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/base.py +116 -43
  24. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/base.py +11 -5
  25. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/llm_labels.py +1 -1
  26. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/topk_inputs.py +1 -1
  27. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/sklearn_wrappers.py +1 -1
  28. interpreto-0.5.0.dev1/interpreto/concepts/probes/__init__.py +81 -0
  29. interpreto-0.5.0.dev1/interpreto/concepts/probes/base.py +279 -0
  30. interpreto-0.5.0.dev1/interpreto/concepts/probes/bias_calibrators.py +223 -0
  31. interpreto-0.5.0.dev1/interpreto/concepts/probes/centroid.py +432 -0
  32. interpreto-0.5.0.dev1/interpreto/concepts/probes/linear.py +418 -0
  33. interpreto-0.5.0.dev1/interpreto/concepts/probes/normalizations.py +201 -0
  34. interpreto-0.5.0.dev1/interpreto/concepts/probes/sklearn.py +140 -0
  35. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/__init__.py +2 -1
  36. interpreto-0.5.0.dev1/interpreto/model_wrapping/classification_inference_wrapper.py +58 -0
  37. interpreto-0.5.0.dev1/interpreto/model_wrapping/generation_inference_wrapper.py +100 -0
  38. interpreto-0.5.0.dev1/interpreto/model_wrapping/inference_wrapper.py +603 -0
  39. interpreto-0.5.0.dev1/interpreto/model_wrapping/inputs_to_concepts_inference_wrapper.py +91 -0
  40. interpreto-0.5.0.dev1/interpreto/model_wrapping/split_sequence_classification.py +427 -0
  41. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/typing.py +8 -14
  42. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto.egg-info/PKG-INFO +23 -13
  43. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto.egg-info/SOURCES.txt +11 -1
  44. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto.egg-info/requires.txt +1 -1
  45. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/mkdocs.yml +17 -9
  46. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/pyproject.toml +2 -2
  47. interpreto-0.4.20/interpreto/concepts/plots/__init__.py +0 -0
  48. interpreto-0.4.20/interpreto/model_wrapping/classification_inference_wrapper.py +0 -227
  49. interpreto-0.4.20/interpreto/model_wrapping/generation_inference_wrapper.py +0 -141
  50. interpreto-0.4.20/interpreto/model_wrapping/inference_wrapper.py +0 -709
  51. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/.editorconfig +0 -0
  52. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/.gitignore +0 -0
  53. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/LICENSE +0 -0
  54. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/MANIFEST.in +0 -0
  55. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/VENDORED_FROM +0 -0
  56. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/__init__.py +0 -0
  57. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/_sync_overcomplete.py +0 -0
  58. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/base.py +0 -0
  59. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/data.py +0 -0
  60. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/metrics.py +0 -0
  61. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/__init__.py +0 -0
  62. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/archetypal_analysis.py +0 -0
  63. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/base.py +0 -0
  64. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/convex_nmf.py +0 -0
  65. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/nmf.py +0 -0
  66. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/semi_nmf.py +0 -0
  67. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/sklearn_wrappers.py +0 -0
  68. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/optimization/utils.py +0 -0
  69. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/overcomplete.patch +0 -0
  70. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/__init__.py +0 -0
  71. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/archetypal_dictionary.py +0 -0
  72. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/base.py +0 -0
  73. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/batchtopk_sae.py +0 -0
  74. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/dictionary.py +0 -0
  75. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/factory.py +0 -0
  76. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/jump_sae.py +0 -0
  77. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/kernels.py +0 -0
  78. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/losses.py +0 -0
  79. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/modules.py +0 -0
  80. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/mp_sae.py +0 -0
  81. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/omp_sae.py +0 -0
  82. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/optimizer.py +0 -0
  83. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/qsae.py +0 -0
  84. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/rasae.py +0 -0
  85. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/topk_sae.py +0 -0
  86. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/trackers.py +0 -0
  87. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/_vendor/overcomplete/sae/train.py +0 -0
  88. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/__init__.py +0 -0
  89. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/__init__.py +0 -0
  90. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/linear_regression_aggregation.py +0 -0
  91. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/aggregations/sobol_aggregation.py +0 -0
  92. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/__init__.py +0 -0
  93. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/gradient_shap.py +0 -0
  94. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/integrated_gradients.py +0 -0
  95. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/smooth_grad.py +0 -0
  96. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/square_grad.py +0 -0
  97. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/methods/var_grad.py +0 -0
  98. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/metrics/__init__.py +0 -0
  99. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/random_perturbation.py +0 -0
  100. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/shap_perturbation.py +0 -0
  101. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/perturbations/sobol_perturbation.py +0 -0
  102. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/attributions/plots/__init__.py +0 -0
  103. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/commons/__init__.py +0 -0
  104. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/commons/distances.py +0 -0
  105. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/commons/generator_tools.py +0 -0
  106. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/commons/granularity.py +0 -0
  107. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/interpretations/__init__.py +0 -0
  108. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/__init__.py +0 -0
  109. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/cockatiel.py +0 -0
  110. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/neurons_as_concepts.py +0 -0
  111. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/methods/overcomplete.py +0 -0
  112. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/__init__.py +0 -0
  113. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/consim.py +0 -0
  114. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/dictionary_metrics.py +0 -0
  115. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/reconstruction_metrics.py +0 -0
  116. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/concepts/metrics/sparsity_metrics.py +0 -0
  117. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/llm_interface.py +0 -0
  118. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/model_with_split_points.py +0 -0
  119. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/splitting_utils.py +0 -0
  120. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/model_wrapping/transformers_classes.py +0 -0
  121. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/README.md +0 -0
  122. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/__init__.py +0 -0
  123. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/attributions.py +0 -0
  124. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/commons.py +0 -0
  125. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/concepts.py +0 -0
  126. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/css/visualization.css +0 -0
  127. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/dom_renderer.js +0 -0
  128. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/state_manager.js +0 -0
  129. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/style_computer.js +0 -0
  130. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/core/view_updater.js +0 -0
  131. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/attribution_classification.js +0 -0
  132. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/attribution_generation.js +0 -0
  133. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/concepts_classification_global.js +0 -0
  134. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/concepts_classification_local.js +0 -0
  135. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto/visualizations/js/visualizations/concepts_generation_local.js +0 -0
  136. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto.egg-info/dependency_links.txt +0 -0
  137. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/interpreto.egg-info/top_level.txt +0 -0
  138. {interpreto-0.4.20 → interpreto-0.5.0.dev1}/setup.cfg +0 -0
@@ -24,7 +24,7 @@ repos:
24
24
  exclude: LICENSE
25
25
 
26
26
  - repo: https://github.com/astral-sh/ruff-pre-commit
27
- rev: v0.14.14
27
+ rev: v0.15.10
28
28
  hooks:
29
29
  - id: ruff-format
30
30
  - id: ruff
@@ -0,0 +1,214 @@
1
+ # AGENTS.md
2
+
3
+ ## Goal
4
+
5
+ `interpreto` is a modular interpretability toolkit for transformer models. The repository aims to provide:
6
+
7
+ - an easy-to-use public API for attribution and concept-based explanations,
8
+ - detailed documentation with concrete examples,
9
+ - precise internal representations for tensors, targets, and activations,
10
+ - reusable building blocks that can be combined without rewriting the whole pipeline.
11
+
12
+ The main product surface is:
13
+
14
+ - attribution methods for classification and generation,
15
+ - concept discovery and concept interpretation workflows,
16
+ - evaluation metrics,
17
+ - HTML visualizations,
18
+ - docs and notebooks showing real usage.
19
+
20
+ ## Repository Map
21
+
22
+ - `interpreto/__init__.py`
23
+ - Curated public API. If a feature is meant to be user-facing, it usually belongs here too.
24
+ - `interpreto/model_wrapping/`
25
+ - Bridges raw Hugging Face models to Interpreto internals.
26
+ - `inference_wrapper.py`: shared batching, device handling, logits/gradient access, padding helpers.
27
+ - `classification_inference_wrapper.py`: targeted scoring for classification tasks.
28
+ - `generation_inference_wrapper.py`: targeted scoring for generation tasks.
29
+ - `model_with_split_points.py`: `nnsight`-based model splitting and activation extraction for concept methods.
30
+ - `llm_interface.py`: abstraction layer for LLM-based concept labeling.
31
+ - `interpreto/attributions/`
32
+ - Attribution framework.
33
+ - `base.py`: shared explainers, normalization, output dataclasses, classification/generation glue.
34
+ - `methods/`: LIME, KernelShap, Occlusion, Sobol, Saliency, Integrated Gradients, SmoothGrad, etc.
35
+ - `perturbations/`: perturbation generators used by attribution methods.
36
+ - `aggregations/`: score aggregation logic.
37
+ - `metrics/`: insertion/deletion evaluation.
38
+ - `interpreto/concepts/`
39
+ - Concept-based interpretability framework.
40
+ - `base.py`: base concept explainer interfaces.
41
+ - `methods/`: neurons-as-concepts, overcomplete/SAE methods, sklearn-based methods, Cockatiel.
42
+ - `interpretations/`: `TopKInputs`, `LLMLabels`, and related interpretation utilities.
43
+ - `metrics/`: reconstruction, sparsity, stability, and ConSim.
44
+ - `interpreto/commons/`
45
+ - Shared utilities such as granularity handling, generator helpers, and distances.
46
+ - `interpreto/typing.py`
47
+ - Central typing aliases and protocols. This file expresses the intended normalized internal shapes and interfaces.
48
+ - `interpreto/visualizations/`
49
+ - HTML/CSS/JS renderers for attribution and concept outputs.
50
+ - Visualizations should consume normalized outputs, not recompute model logic.
51
+ - `interpreto/_vendor/overcomplete/`
52
+ - Vendored dependency for concept learning backends. Avoid touching it unless the change really belongs there.
53
+ - `tests/`
54
+ - Pytest suite. Reuse fixtures from `tests/conftest.py` whenever possible.
55
+ - `docs/`
56
+ - MkDocs source, API pages, and notebooks.
57
+ - `site/`
58
+ - Generated documentation output. Prefer editing `docs/`, not `site/`.
59
+
60
+ ## Key Dependencies
61
+
62
+ - `torch`
63
+ - Core tensor and model execution backend.
64
+ - `transformers`
65
+ - Main model/tokenizer interface and public compatibility target.
66
+ - `nnsight`
67
+ - Used by `ModelWithSplitPoints` for split points and activation capture.
68
+ - `jaxtyping` and `beartype`
69
+ - Preferred tools for explicit tensor typing and shape contracts.
70
+ - `scikit-learn`, `scipy`, `einops`, `matplotlib`, `nltk`
71
+ - Supporting libraries for methods, metrics, preprocessing, and visualization.
72
+ - `bitsandbytes`
73
+ - Compatibility with quantized transformer loading.
74
+ - `mkdocs` stack
75
+ - Documentation build system.
76
+
77
+ ## How The Pieces Interact
78
+
79
+ ### Attribution pipeline
80
+
81
+ User inputs can arrive in several formats: strings, tokenized mappings, tensors, or iterables of those. The code should normalize them early, then keep core computations on one internal format.
82
+
83
+ Typical flow:
84
+
85
+ 1. User input and targets enter an attribution explainer from `interpreto.attributions`.
86
+ 2. The explainer normalizes inputs/targets in `attributions/base.py`.
87
+ 3. A perturbator or gradient path generates the computation stream.
88
+ 4. A task-specific inference wrapper computes targeted logits or gradients.
89
+ 5. An aggregator converts raw scores into final attribution values.
90
+ 6. The result is packaged as `AttributionOutput`.
91
+ 7. Metrics and visualizations consume `AttributionOutput`.
92
+
93
+ Important style point: attribution code is intentionally generator-friendly. Many paths are designed to work sample by sample or batch by batch instead of materializing everything eagerly. Preserve that when making changes, especially for generation and prompt construction logic.
94
+
95
+ ### Concept pipeline
96
+
97
+ Typical flow:
98
+
99
+ 1. `ModelWithSplitPoints` wraps a transformer model and exposes split points.
100
+ 2. `get_activations()` extracts latent activations at a chosen granularity.
101
+ 3. A concept explainer from `interpreto.concepts.methods` fits or applies a concept model on those activations.
102
+ 4. Interpretation methods such as `TopKInputs` or `LLMLabels` map concept dimensions to human-readable descriptions.
103
+ 5. Metrics and visualizations operate on the resulting concept-space artifacts.
104
+
105
+ `ModelWithSplitPoints` is the bridge between the transformer world and concept methods. Most concept changes should respect that layering instead of bypassing it.
106
+
107
+ ### Granularity and normalization
108
+
109
+ Granularity is a core abstraction shared across attribution and concept code. The code often accepts flexible user inputs, but should converge quickly toward:
110
+
111
+ - normalized `TensorMapping`-style model inputs,
112
+ - normalized target tensors,
113
+ - normalized activation tensors,
114
+ - normalized output dataclasses.
115
+
116
+ This repository prefers a flexible public API and a stricter internal core.
117
+
118
+ ## Repository Vibe
119
+
120
+ - Keep the public API easy to use.
121
+ - Users may provide several input formats.
122
+ - Internal computations should still be normalized into a single clear format as early as possible.
123
+ - Prefer precise typing.
124
+ - `jaxtyping` is valuable here because tensor shapes matter a lot for readability and debugging.
125
+ - Be pragmatic at boundaries with `transformers` and `nnsight`; do not make the code worse just to force shape annotations through awkward external APIs.
126
+ - Documentation matters.
127
+ - Detailed docstrings, examples, file-level comments, and inline comments are a feature of the repository, not noise.
128
+ - When adding or changing logic, explain the shape conventions and the intent, especially around generators, token alignment, split points, and concept encoding.
129
+ - The repository is modular.
130
+ - Prefer plug-and-play building blocks over special-purpose monoliths.
131
+ - Reuse wrappers, perturbators, aggregators, metrics, and visualization outputs rather than duplicating logic.
132
+ - Prefer one place for validation.
133
+ - Do not add repeated guardrails in every layer if the check already belongs at the public boundary or is already enforced by typing/contracts.
134
+ - Re-check only if a lower-level function can be called independently or if the invariant genuinely changes.
135
+ - Smaller changes are usually better.
136
+ - Do not refactor by default.
137
+ - If a minimal patch would conflict with the method/class/repository design, then do the slightly larger coherent refactor instead of adding a local hack.
138
+ - Keep implementations efficient but simple.
139
+ - Prefer straightforward Torch code.
140
+ - If a much faster version would add a lot of complexity, it is often better to land the clean version first and leave a focused `TODO`.
141
+ - In attribution code, preserve the generator-based pipeline mindset.
142
+ - The repository often processes attribution sample by sample, while trying to construct good prompts and avoid unnecessary materialization.
143
+
144
+ ## Coding Expectations
145
+
146
+ - Write docstrings and the important inline comments at the same time as the code change, or before.
147
+ - Prefer file-level comments when the whole module has a specific role or subtle invariant.
148
+ - Keep internal data formats explicit.
149
+ - If adding a new public class or function, check whether it should be re-exported in a package `__init__.py` and documented in `docs/`.
150
+ - Use the existing module boundaries.
151
+ - New attribution methods usually belong in `interpreto/attributions/methods/`.
152
+ - New perturbation logic belongs in `interpreto/attributions/perturbations/`.
153
+ - New concept methods belong in `interpreto/concepts/methods/`.
154
+ - New interpretation strategies should use the existing concept explainer interfaces.
155
+
156
+ ## Tests
157
+
158
+ Testing style in this repository is usually a mix of:
159
+
160
+ - method-level tests for specific algorithmic behavior,
161
+ - class-level tests for API and integration behavior,
162
+ - sanity checks for end-to-end invariants.
163
+
164
+ Guidelines:
165
+
166
+ - For a new feature, test-driven development is preferred when practical.
167
+ - Keep tests reviewable. Do not add large numbers of nearly identical tests.
168
+ - Be very clear in test comments/docstrings about what the test is proving.
169
+ - Reuse `tests/conftest.py`, `tests/fixtures/`, and existing helpers before inventing new scaffolding.
170
+ - Prefer `hf-internal-testing/*` tiny models over large custom placeholders or long fake model definitions.
171
+ - Do not test the same invariant in many places unless it protects distinct call paths.
172
+
173
+ ## Change Workflow For Agents
174
+
175
+ 1. Think first.
176
+ - Understand which layer should change.
177
+ - Prefer the smallest coherent modification.
178
+ - If the design tradeoff is uncertain, it is better to ask for an opinion than to guess.
179
+ 2. Add or update tests.
180
+ - For new features or bug fixes, start from the behavior you want to lock in.
181
+ - Reuse fixtures and tiny test models whenever possible.
182
+ 3. Implement the change.
183
+ - Keep the code aligned with existing abstractions.
184
+ - Avoid clever one-off tricks that only satisfy the immediate patch.
185
+ 4. Update documentation if needed.
186
+ - Public API changes usually need docstring and docs updates.
187
+ - Example-driven documentation is part of the repository style.
188
+ 5. Verify with targeted commands first.
189
+
190
+ Useful commands:
191
+
192
+ - `make install-dev`
193
+ - `make lint`
194
+ - `make fast-test`
195
+ - `make test-cpu`
196
+ - `python -m pytest -n auto -c pyproject.toml -v path/to/test_file.py`
197
+
198
+ ## Practical Do / Don't
199
+
200
+ Do:
201
+
202
+ - Normalize flexible user inputs into one internal format early.
203
+ - Use `jaxtyping` where it improves shape clarity.
204
+ - Preserve generator-based or streaming-friendly flows.
205
+ - Add comments where tensor shapes, batching, or prompt construction are non-obvious.
206
+ - Favor small coherent patches.
207
+
208
+ Don't:
209
+
210
+ - Add redundant guardrails in every layer.
211
+ - Materialize huge intermediate lists if the existing pipeline is intentionally iterable/generator-based.
212
+ - Refactor broadly without a concrete design reason.
213
+ - Fight external library APIs just to satisfy an idealized typing style.
214
+ - Edit generated docs in `site/` when the real source lives in `docs/`.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: interpreto
3
- Version: 0.4.20
3
+ Version: 0.5.0.dev1
4
4
  Summary: Interpretability toolbox for LLMs
5
5
  Author: FOR Team
6
6
  Author-email: fanny.jourdan@irt-saintexupery.com
@@ -57,7 +57,7 @@ License-File: LICENSE
57
57
  Requires-Dist: transformers>=4.22.0
58
58
  Requires-Dist: nltk
59
59
  Requires-Dist: torch>=2.0
60
- Requires-Dist: nnsight<0.6.0,>=0.5.1
60
+ Requires-Dist: nnsight<0.8.0,>=0.7.0
61
61
  Requires-Dist: jaxtyping<=0.2.36
62
62
  Requires-Dist: beartype
63
63
  Requires-Dist: mknotebooks
@@ -120,6 +120,7 @@ Dynamic: license-file
120
120
  <p align="center">
121
121
  <a href="https://for-sight-ai.github.io/interpreto/"><strong>📚 Explore Interpreto docs &gt;&gt;</strong></a><br />
122
122
  <a href="https://for-sight-ai.github.io/interpreto-demo/"><strong>🖼️ Check out our explanation gallery &gt;&gt;</strong></a><br />
123
+ <a href="https://arxiv.org/abs/2512.09730"><strong>📜 Read our paper &gt;&gt;</strong></a>
123
124
  </p>
124
125
 
125
126
  ## 🚀 Quick Start
@@ -163,41 +164,51 @@ They all work seamlessly for both classification (`...ForSequenceClassification`
163
164
 
164
165
  Concept-based explanations aim to provide high-level interpretations of latent model representations.
165
166
 
166
- Interpreto generalizes these methods through four core steps:
167
+ We propose both supervised (probes and CAVs) and unsupervised (dictionary learning) approaches.
168
+
169
+ Interpreto generalizes these methods through four core steps, the first two of which are shared by both approaches:
167
170
 
168
171
  1. Split a model in two and obtain a dataset of activations
169
- 2. Concept Discovery (e.g., from latent embeddings)
170
- 3. Concept Interpretation (mapping discovered concepts to human-understandable elements)
171
- 4. Concept-to-Output Attribution (assessing concept relevance to model outputs)
172
+ 2. Learn concepts (e.g., from latent embeddings)
173
+ 3. Interpret concepts (mapping discovered concepts to human-understandable elements)
174
+ 4. Estimate concept importance (assessing concept relevance to model outputs)
172
175
 
173
176
  **1. Split a model in two and obtain a dataset of activations:** (mainly via [`nnsight`](https://github.com/ndif-team/nnsight)):
174
177
 
175
178
  Choose any layer in any HuggingFace language model with our `ModelWithSplitPoints` based on `nnsight`. Then pass a dataset through it to obtain a dataset of activations.
176
179
 
177
- **2. Dictionary Learning for Concept Discovery** (mainly via [`overcomplete`](https://github.com/KempnerInstitute/overcomplete)):
180
+ **2. (supervised) Train probes** with the [`ProbeExplainer`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/)
181
+
182
+ We differentiate two families of probes:
183
+
184
+ - Linear probes: [`LinearRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearRegressionProbe), [`LogisticRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LogisticRegressionProbe), [`LinearSVMProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearSVMProbe), [`MeansDiffProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.MeansDiffProbe)
185
+ - Centroid-based probes: [`CosineCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.CosineCentroidProbe), [`DotProductCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DotProductCentroidProbe), [`SqL2CentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SqL2CentroidProbe), [`SVDDCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SVDDCentroidProbe), [`DiagonalMahalanobisCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DiagonalMahalanobisCentroidProbe)
186
+
187
+ Both families of probes can be tuned with the `bias_calibrator` and `normalization` parameters.
188
+
189
+ **2. (unsupervised) Dictionary Learning for Concept Discovery** (mainly via [`overcomplete`](https://github.com/KempnerInstitute/overcomplete)):
178
190
 
179
191
  - Interpret neurons directly via [`NeuronsAsConcepts`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/neurons_as_concepts/)
180
192
  - [`NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.NMFConcepts), [`Semi-NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SemiNMFConcepts), [`ConvexNMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ConvexNMFConcepts)
181
193
  - [`ICA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ICAConcepts), [`SVD`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SVDConcepts), [`PCA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.PCAConcepts), [`KMeans`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.KMeansConcepts)
182
194
  - SAE variants: [`Vanilla SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.VanillaSAEConcepts), [`TopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.TopKSAEConcepts), [`JumpReLU SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.JumpReLUSAEConcepts), [`BatchTopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.BatchTopKSAEConcepts)
183
195
 
184
- **3. Available Concept Interpretation Techniques:**
196
+ **3. (unsupervised) Available Concept Interpretation Techniques:**
185
197
 
186
198
  - Top-k tokens from tokenizer vocabulary via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs) and `use_vocab=True`
187
199
  - Top-k tokens/words/sentences/samples from specific datasets via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs)
188
200
  - Label concepts via LLMs with [`LLMLabels`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.LLMLabels) ([Bills et al. 2023](https://openai.com/index/language-models-can-explain-neurons-in-language-models/))
201
+ - Input-to-concept attribution from dataset examples ([Concept Attributions](https://for-sight-ai.github.io/interpreto/api/concepts/interpretations/concept_attributions/)) ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
189
202
 
190
203
  <details><summary>Concept Interpretation Techniques Added in the future:</summary>
191
204
 
192
- - Input-to-concept attribution from dataset examples ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
193
- - Theme prediction via LLMs from top-k tokens/sentences
194
205
  - Aligning concepts with human labels ([Sajjad et al. 2022](https://aclanthology.org/2022.naacl-main.225/))
195
206
  - Word cloud visualizations of concepts ([Dalvi et al. 2022](https://arxiv.org/abs/2205.07237))
196
207
  - VocabProj & TokenChange ([Gur-Arieh et al. 2025](https://arxiv.org/abs/2501.08319))
197
208
 
198
209
  </details>
199
210
 
200
- **4. Concept-to-Output Attribution:**
211
+ **4. (unsupervised) Concept-to-Output Attribution:**
201
212
 
202
213
  Estimate the contribution of each concept to the model output.
203
214
 
@@ -207,7 +218,6 @@ Can be obtained with any concept-based explainer via [`MethodConcepts.concept_ou
207
218
 
208
219
  Thanks to this generalization encompassing all concept-based methods and our highly flexible architecture, we can easily obtain a large number of concept-based methods:
209
220
 
210
- - CAV and TCAV: [Kim et al. 2018, Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV)](http://proceedings.mlr.press/v80/kim18d.html)
211
221
  - ConceptSHAP: [Yeh et al. 2020, On Completeness-aware Concept-Based Explanations in Deep Neural Networks](https://proceedings.neurips.cc/paper/2020/hash/ecb287ff763c169694f682af52c1f309-Abstract.html)
212
222
  - COCKATIEL: [Jourdan et al. 2023, COCKATIEL: COntinuous Concept ranKed ATtribution with Interpretable ELements for explaining neural net classifiers on NLP](https://aclanthology.org/2023.findings-acl.317/)
213
223
  - Yun et al. 2021, [Transformer visualization via dictionary learning: contextualized embedding as a linear superposition of transformer factors](https://arxiv.org/abs/2103.15949)
@@ -264,7 +274,7 @@ Interpreto 🪄 is a project of the [FOR](https://www.irt-saintexupery.com/fr/fo
264
274
 
265
275
  ## 🗞️ Citation
266
276
 
267
- If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ our paper:
277
+ If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ [our paper](https://arxiv.org/abs/2512.09730):
268
278
 
269
279
  ```bibtex
270
280
  @article{poche2025interpreto,
@@ -15,6 +15,7 @@
15
15
  <p align="center">
16
16
  <a href="https://for-sight-ai.github.io/interpreto/"><strong>📚 Explore Interpreto docs &gt;&gt;</strong></a><br />
17
17
  <a href="https://for-sight-ai.github.io/interpreto-demo/"><strong>🖼️ Check out our explanation gallery &gt;&gt;</strong></a><br />
18
+ <a href="https://arxiv.org/abs/2512.09730"><strong>📜 Read our paper &gt;&gt;</strong></a>
18
19
  </p>
19
20
 
20
21
  ## 🚀 Quick Start
@@ -58,41 +59,51 @@ They all work seamlessly for both classification (`...ForSequenceClassification`
58
59
 
59
60
  Concept-based explanations aim to provide high-level interpretations of latent model representations.
60
61
 
61
- Interpreto generalizes these methods through four core steps:
62
+ We propose both supervised (probes and CAVs) and unsupervised (dictionary learning) approaches.
63
+
64
+ Interpreto generalizes these methods through four core steps; the first two are common to both approaches:
62
65
 
63
66
  1. Split a model in two and obtain a dataset of activations
64
- 2. Concept Discovery (e.g., from latent embeddings)
65
- 3. Concept Interpretation (mapping discovered concepts to human-understandable elements)
66
- 4. Concept-to-Output Attribution (assessing concept relevance to model outputs)
67
+ 2. Learn concepts (e.g., from latent embeddings)
68
+ 3. Interpret concepts (mapping discovered concepts to human-understandable elements)
69
+ 4. Estimate concept importance (assessing concept relevance to model outputs)
67
70
 
68
71
  **1. Split a model in two and obtain a dataset of activations:** (mainly via [`nnsight`](https://github.com/ndif-team/nnsight)):
69
72
 
70
73
  Choose any layer in any HuggingFace language model with our `ModelWithSplitPoints` based on `nnsight`. Then pass a dataset through it to obtain a dataset of activations.
71
74
 
72
- **2. Dictionary Learning for Concept Discovery** (mainly via [`overcomplete`](https://github.com/KempnerInstitute/overcomplete)):
75
+ **2. (supervised) Train a probe** with the [`ProbeExplainer`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/):
76
+
77
+ We differentiate two families of probes:
78
+
79
+ - Linear probes: [`LinearRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearRegressionProbe), [`LogisticRegressionProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LogisticRegressionProbe), [`LinearSVMProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.LinearSVMProbe), [`MeansDiffProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.MeansDiffProbe)
80
+ - Centroid-based probes: [`CosineCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.CosineCentroidProbe), [`DotProductCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DotProductCentroidProbe), [`SqL2CentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SqL2CentroidProbe), [`SVDDCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.SVDDCentroidProbe), [`DiagonalMahalanobisCentroidProbe`](https://for-sight-ai.github.io/interpreto/api/concepts/probes/#interpreto.concepts.probes.DiagonalMahalanobisCentroidProbe)
81
+
82
+ Both probe families can be tuned with the `bias_calibrator` and `normalization` parameters.
83
+
84
+ **2. (unsupervised) Dictionary Learning for Concept Discovery** (mainly via [`overcomplete`](https://github.com/KempnerInstitute/overcomplete)):
73
85
 
74
86
  - Interpret neurons directly via [`NeuronsAsConcepts`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/neurons_as_concepts/)
75
87
  - [`NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.NMFConcepts), [`Semi-NMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SemiNMFConcepts), [`ConvexNMF`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ConvexNMFConcepts)
76
88
  - [`ICA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.ICAConcepts), [`SVD`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.SVDConcepts), [`PCA`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.PCAConcepts), [`KMeans`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/optim/#interpreto.concepts.KMeansConcepts)
77
89
  - SAE variants: [`Vanilla SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.VanillaSAEConcepts), [`TopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.TopKSAEConcepts), [`JumpReLU SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.JumpReLUSAEConcepts), [`BatchTopK SAE`](https://for-sight-ai.github.io/interpreto/api/concepts/methods/sae/#interpreto.concepts.BatchTopKSAEConcepts)
78
90
 
79
- **3. Available Concept Interpretation Techniques:**
91
+ **3. (unsupervised) Available Concept Interpretation Techniques:**
80
92
 
81
93
  - Top-k tokens from tokenizer vocabulary via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs) and `use_vocab=True`
82
94
  - Top-k tokens/words/sentences/samples from specific datasets via [`TopKInputs`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.TopKInputs)
83
95
  - Label concepts via LLMs with [`LLMLabels`](https://for-sight-ai.github.io/interpreto/api/concepts/concepts_interpretations/#interpreto.concepts.interpretations.LLMLabels) ([Bills et al. 2023](https://openai.com/index/language-models-can-explain-neurons-in-language-models/))
96
+ - Input-to-concept attribution from dataset examples ([Concept Attributions](https://for-sight-ai.github.io/interpreto/api/concepts/interpretations/concept_attributions/)) ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
84
97
 
85
98
  <details><summary>Concept Interpretation Techniques to Be Added in the Future:</summary>
86
99
 
87
- - Input-to-concept attribution from dataset examples ([Jourdan et al. 2023](https://aclanthology.org/2023.findings-acl.317/))
88
- - Theme prediction via LLMs from top-k tokens/sentences
89
100
  - Aligning concepts with human labels ([Sajjad et al. 2022](https://aclanthology.org/2022.naacl-main.225/))
90
101
  - Word cloud visualizations of concepts ([Dalvi et al. 2022](https://arxiv.org/abs/2205.07237))
91
102
  - VocabProj & TokenChange ([Gur-Arieh et al. 2025](https://arxiv.org/abs/2501.08319))
92
103
 
93
104
  </details>
94
105
 
95
- **4. Concept-to-Output Attribution:**
106
+ **4. (unsupervised) Concept-to-Output Attribution:**
96
107
 
97
108
  Estimate the contribution of each concept to the model output.
98
109
 
@@ -102,7 +113,6 @@ Can be obtained with any concept-based explainer via [`MethodConcepts.concept_ou
102
113
 
103
114
  Thanks to this generalization encompassing all concept-based methods and our highly flexible architecture, we can easily obtain a large number of concept-based methods:
104
115
 
105
- - CAV and TCAV: [Kim et al. 2018, Interpretability Beyond Feature Attribution: Quantitative Testing with Concept Activation Vectors (TCAV)](http://proceedings.mlr.press/v80/kim18d.html)
106
116
  - ConceptSHAP: [Yeh et al. 2020, On Completeness-aware Concept-Based Explanations in Deep Neural Networks](https://proceedings.neurips.cc/paper/2020/hash/ecb287ff763c169694f682af52c1f309-Abstract.html)
107
117
  - COCKATIEL: [Jourdan et al. 2023, COCKATIEL: COntinuous Concept ranKed ATtribution with Interpretable ELements for explaining neural net classifiers on NLP](https://aclanthology.org/2023.findings-acl.317/)
108
118
  - Yun et al. 2021, [Transformer visualization via dictionary learning: contextualized embedding as a linear superposition of transformer factors](https://arxiv.org/abs/2103.15949)
@@ -159,7 +169,7 @@ Interpreto 🪄 is a project of the [FOR](https://www.irt-saintexupery.com/fr/fo
159
169
 
160
170
  ## 🗞️ Citation
161
171
 
162
- If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ our paper:
172
+ If you use Interpreto 🪄 as part of your workflow in a scientific publication, please consider citing 🗞️ [our paper](https://arxiv.org/abs/2512.09730):
163
173
 
164
174
  ```bibtex
165
175
  @article{poche2025interpreto,
@@ -42,7 +42,7 @@ from .attributions import (
42
42
  from .commons import (
43
43
  Granularity,
44
44
  )
45
- from .model_wrapping import ModelWithSplitPoints
45
+ from .model_wrapping import ModelWithSplitPoints, SplitSequenceClassification
46
46
  from .visualizations import (
47
47
  AttributionVisualization,
48
48
  plot_attributions,
@@ -71,6 +71,7 @@ __all__ = [
71
71
  "SquareGrad",
72
72
  "Saliency",
73
73
  "SmoothGrad",
74
+ "SplitSequenceClassification",
74
75
  "Sobol",
75
76
  "VarGrad",
76
77
  "get_version",
@@ -0,0 +1,16 @@
1
+ from abc import ABC
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+ class BaseDictionaryLearning(ABC, nn.Module):
7
+ nb_concepts: int
8
+ device: str | torch.device
9
+ fitted: bool
10
+
11
+ def __init__(self, nb_concepts: int, device: str | torch.device = "cpu") -> None: ...
12
+ def encode(self, x: torch.Tensor) -> torch.Tensor: ...
13
+ def decode(self, z: torch.Tensor) -> torch.Tensor: ...
14
+ def fit(self, x: torch.Tensor) -> None: ...
15
+ def get_dictionary(self) -> torch.Tensor: ...
16
+ def to(self, device: str | torch.device) -> "BaseDictionaryLearning": ... # type: ignore[override]
@@ -41,10 +41,9 @@ def cast_input_to_dtype(func):
41
41
  Ensure mask and results are on the device specified in the aggregator
42
42
  """
43
43
 
44
- def wrapper(self, results: torch.Tensor, mask, *args, **kwargs) -> torch.Tensor:
45
- # TODO : eventually add device alignment as well
46
- if mask is not None and mask.dtype != self.dtype:
47
- mask = mask.to(self.dtype)
44
+ def wrapper(self, results: torch.Tensor, mask: torch.Tensor | None, *args, **kwargs) -> torch.Tensor:
45
+ if mask is not None:
46
+ mask = mask.to(device=results.device, dtype=self.dtype)
48
47
  return func(self, results.to(self.dtype), mask, *args, **kwargs)
49
48
 
50
49
  return wrapper