inter-gnn 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- inter_gnn-0.1.0/PKG-INFO +288 -0
- inter_gnn-0.1.0/README.md +243 -0
- inter_gnn-0.1.0/inter_gnn/__init__.py +23 -0
- inter_gnn-0.1.0/inter_gnn/cli.py +248 -0
- inter_gnn-0.1.0/inter_gnn/data/__init__.py +27 -0
- inter_gnn-0.1.0/inter_gnn/data/cliffs.py +283 -0
- inter_gnn-0.1.0/inter_gnn/data/concepts.py +369 -0
- inter_gnn-0.1.0/inter_gnn/data/datamodule.py +143 -0
- inter_gnn-0.1.0/inter_gnn/data/datasets.py +202 -0
- inter_gnn-0.1.0/inter_gnn/data/featurize.py +384 -0
- inter_gnn-0.1.0/inter_gnn/data/protein.py +410 -0
- inter_gnn-0.1.0/inter_gnn/data/splits.py +373 -0
- inter_gnn-0.1.0/inter_gnn/data/standardize.py +288 -0
- inter_gnn-0.1.0/inter_gnn/evaluation/__init__.py +24 -0
- inter_gnn-0.1.0/inter_gnn/evaluation/causal.py +104 -0
- inter_gnn-0.1.0/inter_gnn/evaluation/chemical_validity.py +134 -0
- inter_gnn-0.1.0/inter_gnn/evaluation/faithfulness.py +176 -0
- inter_gnn-0.1.0/inter_gnn/evaluation/predictive.py +145 -0
- inter_gnn-0.1.0/inter_gnn/evaluation/stability_metrics.py +115 -0
- inter_gnn-0.1.0/inter_gnn/evaluation/statistical.py +119 -0
- inter_gnn-0.1.0/inter_gnn/explainers/__init__.py +11 -0
- inter_gnn-0.1.0/inter_gnn/explainers/cf_explainer.py +201 -0
- inter_gnn-0.1.0/inter_gnn/explainers/cider.py +207 -0
- inter_gnn-0.1.0/inter_gnn/explainers/t_explainer.py +176 -0
- inter_gnn-0.1.0/inter_gnn/interpretability/__init__.py +14 -0
- inter_gnn-0.1.0/inter_gnn/interpretability/concept_whitening.py +171 -0
- inter_gnn-0.1.0/inter_gnn/interpretability/motifs.py +198 -0
- inter_gnn-0.1.0/inter_gnn/interpretability/prototypes.py +195 -0
- inter_gnn-0.1.0/inter_gnn/interpretability/stability.py +183 -0
- inter_gnn-0.1.0/inter_gnn/models/__init__.py +17 -0
- inter_gnn-0.1.0/inter_gnn/models/attention.py +225 -0
- inter_gnn-0.1.0/inter_gnn/models/core_model.py +221 -0
- inter_gnn-0.1.0/inter_gnn/models/encoders.py +244 -0
- inter_gnn-0.1.0/inter_gnn/models/task_heads.py +150 -0
- inter_gnn-0.1.0/inter_gnn/training/__init__.py +11 -0
- inter_gnn-0.1.0/inter_gnn/training/callbacks.py +161 -0
- inter_gnn-0.1.0/inter_gnn/training/config.py +178 -0
- inter_gnn-0.1.0/inter_gnn/training/losses.py +179 -0
- inter_gnn-0.1.0/inter_gnn/training/trainer.py +308 -0
- inter_gnn-0.1.0/inter_gnn/visualization/__init__.py +15 -0
- inter_gnn-0.1.0/inter_gnn/visualization/concept_viz.py +120 -0
- inter_gnn-0.1.0/inter_gnn/visualization/counterfactual_viz.py +148 -0
- inter_gnn-0.1.0/inter_gnn/visualization/dashboard.py +204 -0
- inter_gnn-0.1.0/inter_gnn/visualization/molecule_viz.py +185 -0
- inter_gnn-0.1.0/inter_gnn/visualization/motif_viz.py +133 -0
- inter_gnn-0.1.0/inter_gnn/visualization/prototype_viz.py +132 -0
- inter_gnn-0.1.0/inter_gnn.egg-info/PKG-INFO +288 -0
- inter_gnn-0.1.0/inter_gnn.egg-info/SOURCES.txt +52 -0
- inter_gnn-0.1.0/inter_gnn.egg-info/dependency_links.txt +1 -0
- inter_gnn-0.1.0/inter_gnn.egg-info/entry_points.txt +2 -0
- inter_gnn-0.1.0/inter_gnn.egg-info/requires.txt +23 -0
- inter_gnn-0.1.0/inter_gnn.egg-info/top_level.txt +1 -0
- inter_gnn-0.1.0/pyproject.toml +89 -0
- inter_gnn-0.1.0/setup.cfg +4 -0
inter_gnn-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: inter-gnn
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Interpretable GNN-Based Framework for Drug Discovery and Candidate Screening
|
|
5
|
+
Author: Harshal Loya, Jash Chauhan, Het Gala
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/inter-gnn/inter-gnn
|
|
8
|
+
Project-URL: Documentation, https://inter-gnn.readthedocs.io
|
|
9
|
+
Project-URL: Repository, https://github.com/inter-gnn/inter-gnn
|
|
10
|
+
Keywords: graph-neural-networks,drug-discovery,explainable-ai,molecular-property-prediction,interpretability,activity-cliffs,concept-whitening
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Chemistry
|
|
22
|
+
Requires-Python: >=3.9
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
Requires-Dist: torch>=2.0.0
|
|
25
|
+
Requires-Dist: torch-geometric>=2.4.0
|
|
26
|
+
Requires-Dist: rdkit>=2023.3.1
|
|
27
|
+
Requires-Dist: numpy>=1.24.0
|
|
28
|
+
Requires-Dist: scipy>=1.10.0
|
|
29
|
+
Requires-Dist: pandas>=2.0.0
|
|
30
|
+
Requires-Dist: scikit-learn>=1.2.0
|
|
31
|
+
Requires-Dist: matplotlib>=3.7.0
|
|
32
|
+
Requires-Dist: pyyaml>=6.0
|
|
33
|
+
Requires-Dist: tqdm>=4.65.0
|
|
34
|
+
Provides-Extra: viz
|
|
35
|
+
Requires-Dist: plotly>=5.14.0; extra == "viz"
|
|
36
|
+
Requires-Dist: py3Dmol>=2.0.0; extra == "viz"
|
|
37
|
+
Requires-Dist: seaborn>=0.12.0; extra == "viz"
|
|
38
|
+
Requires-Dist: ipywidgets>=8.0.0; extra == "viz"
|
|
39
|
+
Provides-Extra: dev
|
|
40
|
+
Requires-Dist: pytest>=7.3.0; extra == "dev"
|
|
41
|
+
Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
|
|
42
|
+
Requires-Dist: black>=23.0.0; extra == "dev"
|
|
43
|
+
Requires-Dist: ruff>=0.1.0; extra == "dev"
|
|
44
|
+
Requires-Dist: mypy>=1.4.0; extra == "dev"
|
|
45
|
+
|
|
46
|
+
# InterGNN — Interpretable GNN-Based Framework for Drug Discovery
|
|
47
|
+
|
|
48
|
+

|
|
49
|
+

|
|
50
|
+

|
|
51
|
+
|
|
52
|
+
An interpretable Graph Neural Network framework combining state-of-the-art molecular property prediction with inherent and post-hoc explainability methods. Designed for drug discovery workflows requiring trust, transparency, and scientific insight.
|
|
53
|
+
|
|
54
|
+
---
|
|
55
|
+
|
|
56
|
+
## Architecture
|
|
57
|
+
|
|
58
|
+
```
|
|
59
|
+
SMILES → Standardize → Featurize → MolecularGNNEncoder ──┐
|
|
60
|
+
├─ CrossAttention → TaskHead → Prediction
|
|
61
|
+
Protein → ProteinGraphBuilder → TargetGNNEncoder ─────────┘
|
|
62
|
+
│
|
|
63
|
+
┌───────────────┼───────────────┐
|
|
64
|
+
▼ ▼ ▼
|
|
65
|
+
PrototypeLayer MotifHead ConceptWhitening
|
|
66
|
+
(case-based) (substructure) (axis-aligned)
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
### Key Features
|
|
70
|
+
|
|
71
|
+
| Feature | Description |
|
|
72
|
+
|---------|-------------|
|
|
73
|
+
| **Molecular Encoder** | GINEConv with edge-aware message passing and chirality features |
|
|
74
|
+
| **Target Encoder** | Multi-head GATConv for residue-level protein graphs |
|
|
75
|
+
| **Cross-Attention Fusion** | Atom-residue interaction for drug-target affinity |
|
|
76
|
+
| **PAGE Prototypes** | Case-based classification via learned prototypes |
|
|
77
|
+
| **MAGE Motifs** | Differentiable motif mask generation with Gumbel-sigmoid |
|
|
78
|
+
| **Concept Whitening** | ZCA whitening + axis-aligned concept interpretability |
|
|
79
|
+
| **CF-GNNExplainer** | Counterfactual minimal perturbation explanations |
|
|
80
|
+
| **T-GNNExplainer** | Sufficient subgraph identification |
|
|
81
|
+
| **CIDER Diagnostics** | Causal invariance testing across environments |
|
|
82
|
+
|
|
83
|
+
---
|
|
84
|
+
|
|
85
|
+
## Installation
|
|
86
|
+
|
|
87
|
+
```bash
|
|
88
|
+
# Clone the repository
|
|
89
|
+
git clone https://github.com/your-org/Inter_gnn.git
|
|
90
|
+
cd Inter_gnn
|
|
91
|
+
|
|
92
|
+
# Install with all dependencies
|
|
93
|
+
pip install -e ".[viz,dev]"
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
### Requirements
|
|
97
|
+
|
|
98
|
+
- Python ≥ 3.9
|
|
99
|
+
- PyTorch ≥ 2.0
|
|
100
|
+
- PyTorch Geometric ≥ 2.4
|
|
101
|
+
- RDKit ≥ 2023.03
|
|
102
|
+
- NumPy, SciPy, Pandas, scikit-learn, matplotlib
|
|
103
|
+
|
|
104
|
+
---
|
|
105
|
+
|
|
106
|
+
## Quick Start
|
|
107
|
+
|
|
108
|
+
### 1. Create a Configuration
|
|
109
|
+
|
|
110
|
+
```yaml
|
|
111
|
+
# config.yaml
|
|
112
|
+
data:
|
|
113
|
+
dataset_name: tox21
|
|
114
|
+
split_method: scaffold
|
|
115
|
+
batch_size: 32
|
|
116
|
+
detect_cliffs: true
|
|
117
|
+
compute_concepts: true
|
|
118
|
+
|
|
119
|
+
model:
|
|
120
|
+
hidden_dim: 256
|
|
121
|
+
num_mol_layers: 4
|
|
122
|
+
task_type: classification
|
|
123
|
+
num_tasks: 12
|
|
124
|
+
|
|
125
|
+
interpretability:
|
|
126
|
+
use_prototypes: true
|
|
127
|
+
num_prototypes_per_class: 5
|
|
128
|
+
use_motifs: true
|
|
129
|
+
num_motifs: 8
|
|
130
|
+
use_concept_whitening: true
|
|
131
|
+
|
|
132
|
+
training:
|
|
133
|
+
pretrain_epochs: 50
|
|
134
|
+
finetune_epochs: 100
|
|
135
|
+
learning_rate: 0.001
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
### 2. Train
|
|
139
|
+
|
|
140
|
+
```bash
|
|
141
|
+
inter-gnn train --config config.yaml
|
|
142
|
+
```
|
|
143
|
+
|
|
144
|
+
### 3. Evaluate
|
|
145
|
+
|
|
146
|
+
```bash
|
|
147
|
+
inter-gnn evaluate --config config.yaml --checkpoint checkpoints/finetune_best.pt
|
|
148
|
+
```
|
|
149
|
+
|
|
150
|
+
### 4. Generate Explanations
|
|
151
|
+
|
|
152
|
+
```bash
|
|
153
|
+
inter-gnn explain --config config.yaml --checkpoint model.pt --smiles "CC(=O)Oc1ccccc1C(=O)O"
|
|
154
|
+
```
|
|
155
|
+
|
|
156
|
+
### 5. Dashboard
|
|
157
|
+
|
|
158
|
+
```bash
|
|
159
|
+
inter-gnn dashboard --config config.yaml --checkpoint model.pt --output report/
|
|
160
|
+
```
|
|
161
|
+
|
|
162
|
+
---
|
|
163
|
+
|
|
164
|
+
## Python API
|
|
165
|
+
|
|
166
|
+
```python
|
|
167
|
+
from inter_gnn.training.config import InterGNNConfig
|
|
168
|
+
from inter_gnn.training.trainer import InterGNNTrainer
|
|
169
|
+
from inter_gnn.data.datamodule import InterGNNDataModule
|
|
170
|
+
|
|
171
|
+
# Load config
|
|
172
|
+
config = InterGNNConfig.from_yaml("config.yaml")
|
|
173
|
+
|
|
174
|
+
# Build data
|
|
175
|
+
dm = InterGNNDataModule(config)
|
|
176
|
+
dm.prepare_data()
|
|
177
|
+
dm.setup()
|
|
178
|
+
|
|
179
|
+
# Train (two-phase: pretrain → finetune)
|
|
180
|
+
trainer = InterGNNTrainer(config)
|
|
181
|
+
history = trainer.fit(dm.train_dataloader(), dm.val_dataloader())
|
|
182
|
+
|
|
183
|
+
# Explain a molecule
|
|
184
|
+
from inter_gnn.data.featurize import smiles_to_graph
|
|
185
|
+
import torch
|
|
186
|
+
|
|
187
|
+
graph = smiles_to_graph("CC(=O)Oc1ccccc1C(=O)O")
|
|
188
|
+
batch = torch.zeros(graph.x.shape[0], dtype=torch.long)
|
|
189
|
+
output = trainer.model(graph.x, graph.edge_index, graph.edge_attr, batch)
|
|
190
|
+
|
|
191
|
+
importance = trainer.model.get_node_importance(
|
|
192
|
+
graph.x, graph.edge_index, graph.edge_attr, batch
|
|
193
|
+
)
|
|
194
|
+
```
|
|
195
|
+
|
|
196
|
+
---
|
|
197
|
+
|
|
198
|
+
## Module Overview
|
|
199
|
+
|
|
200
|
+
```
|
|
201
|
+
inter_gnn/
|
|
202
|
+
├── data/ # Data & Preprocessing
|
|
203
|
+
│ ├── standardize.py # Molecule standardization (tautomer, charge, stereo)
|
|
204
|
+
│ ├── featurize.py # SMILES → molecular graph (~78-dim atom, ~14-dim bond)
|
|
205
|
+
│ ├── protein.py # Protein sequence → k-NN / contact graph
|
|
206
|
+
│ ├── concepts.py # SMARTS concept library (~30 patterns)
|
|
207
|
+
│ ├── cliffs.py # Activity cliff detection
|
|
208
|
+
│ ├── splits.py # Scaffold, cold-target, temporal splits
|
|
209
|
+
│ ├── datasets.py # 9 benchmark dataset loaders
|
|
210
|
+
│ └── datamodule.py # DataModule wrapper
|
|
211
|
+
├── models/ # Core Model
|
|
212
|
+
│ ├── encoders.py # GINEConv (molecule) + GATConv (protein) encoders
|
|
213
|
+
│ ├── attention.py # Cross-attention fusion + bilinear alternative
|
|
214
|
+
│ ├── task_heads.py # Classification + regression heads
|
|
215
|
+
│ └── core_model.py # Unified InterGNN model
|
|
216
|
+
├── interpretability/ # Intrinsic Interpretability
|
|
217
|
+
│ ├── prototypes.py # PAGE-inspired prototype layer
|
|
218
|
+
│ ├── motifs.py # MAGE-inspired motif generator
|
|
219
|
+
│ ├── concept_whitening.py # ZCA whitening + concept alignment
|
|
220
|
+
│ └── stability.py # Explanation stability regularizer
|
|
221
|
+
├── explainers/ # Post-hoc Explanations
|
|
222
|
+
│ ├── cf_explainer.py # CF-GNNExplainer (counterfactual)
|
|
223
|
+
│ ├── t_explainer.py # T-GNNExplainer (sufficient subgraph)
|
|
224
|
+
│ └── cider.py # CIDER causal invariance diagnostics
|
|
225
|
+
├── training/ # Training Pipeline
|
|
226
|
+
│ ├── losses.py # Combined multi-objective loss
|
|
227
|
+
│ ├── trainer.py # Two-phase trainer (pretrain + finetune)
|
|
228
|
+
│ ├── callbacks.py # EarlyStopping, checkpointing, monitoring
|
|
229
|
+
│ └── config.py # YAML config with dataclass hierarchy
|
|
230
|
+
├── evaluation/ # Evaluation Metrics
|
|
231
|
+
│ ├── predictive.py # ROC-AUC, PR-AUC, RMSE, CI, etc.
|
|
232
|
+
│ ├── faithfulness.py # Deletion/Insertion AUC, sufficiency/necessity
|
|
233
|
+
│ ├── stability_metrics.py # Jaccard stability, cliff consistency
|
|
234
|
+
│ ├── chemical_validity.py # Valence checks, SMARTS match rates
|
|
235
|
+
│ ├── causal.py # Invariance violation, environment alignment
|
|
236
|
+
│ └── statistical.py # Paired bootstrap, randomization tests
|
|
237
|
+
├── visualization/ # Visualization Tools
|
|
238
|
+
│ ├── molecule_viz.py # Atom/bond saliency rendering
|
|
239
|
+
│ ├── prototype_viz.py # Prototype gallery
|
|
240
|
+
│ ├── motif_viz.py # Motif activation heatmaps
|
|
241
|
+
│ ├── concept_viz.py # Concept activation bars
|
|
242
|
+
│ ├── counterfactual_viz.py# Counterfactual edit display
|
|
243
|
+
│ └── dashboard.py # HTML batch-export dashboard
|
|
244
|
+
└── cli.py # Command-line interface
|
|
245
|
+
```
|
|
246
|
+
|
|
247
|
+
---
|
|
248
|
+
|
|
249
|
+
## Supported Datasets
|
|
250
|
+
|
|
251
|
+
| Dataset | Type | Tasks | Source |
|
|
252
|
+
|---------|------|-------|--------|
|
|
253
|
+
| MUTAG | Classification | 1 | TUDataset |
|
|
254
|
+
| Tox21 | Classification | 12 | MoleculeNet |
|
|
255
|
+
| ClinTox | Classification | 2 | MoleculeNet |
|
|
256
|
+
| QM9 | Regression | 19 | MoleculeNet |
|
|
257
|
+
| Davis | DTA Regression | 1 | TDC |
|
|
258
|
+
| KIBA | DTA Regression | 1 | TDC |
|
|
259
|
+
| BindingDB | DTA Regression | 1 | TDC |
|
|
260
|
+
| SIDER | Classification | 27 | MoleculeNet |
|
|
261
|
+
| SynLethDB | Classification | 1 | Custom |
|
|
262
|
+
|
|
263
|
+
---
|
|
264
|
+
|
|
265
|
+
## Two-Phase Training
|
|
266
|
+
|
|
267
|
+
1. **Pre-training** — Trains encoders + task head with prediction loss only
|
|
268
|
+
2. **Joint Fine-tuning** — Attaches interpretability modules, trains all losses:
|
|
269
|
+
- `L_pred`: Task prediction (BCE/MSE)
|
|
270
|
+
- `L_pull/push/div`: Prototype losses
|
|
271
|
+
- `L_sparsity/conn`: Motif losses
|
|
272
|
+
- `L_align/decorr`: Concept whitening losses
|
|
273
|
+
- `L_stability`: Explanation stability
|
|
274
|
+
|
|
275
|
+
---
|
|
276
|
+
|
|
277
|
+
## Citation
|
|
278
|
+
|
|
279
|
+
```bibtex
|
|
280
|
+
@software{inter_gnn2025,
|
|
281
|
+
title={InterGNN: Interpretable Graph Neural Network for Drug Discovery},
|
|
282
|
+
year={2025},
|
|
283
|
+
}
|
|
284
|
+
```
|
|
285
|
+
|
|
286
|
+
## License
|
|
287
|
+
|
|
288
|
+
MIT License
|
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
# InterGNN — Interpretable GNN-Based Framework for Drug Discovery
|
|
2
|
+
|
|
3
|
+

|
|
4
|
+

|
|
5
|
+

|
|
6
|
+
|
|
7
|
+
An interpretable Graph Neural Network framework combining state-of-the-art molecular property prediction with inherent and post-hoc explainability methods. Designed for drug discovery workflows requiring trust, transparency, and scientific insight.
|
|
8
|
+
|
|
9
|
+
---
|
|
10
|
+
|
|
11
|
+
## Architecture
|
|
12
|
+
|
|
13
|
+
```
|
|
14
|
+
SMILES → Standardize → Featurize → MolecularGNNEncoder ──┐
|
|
15
|
+
├─ CrossAttention → TaskHead → Prediction
|
|
16
|
+
Protein → ProteinGraphBuilder → TargetGNNEncoder ─────────┘
|
|
17
|
+
│
|
|
18
|
+
┌───────────────┼───────────────┐
|
|
19
|
+
▼ ▼ ▼
|
|
20
|
+
PrototypeLayer MotifHead ConceptWhitening
|
|
21
|
+
(case-based) (substructure) (axis-aligned)
|
|
22
|
+
```
|
|
23
|
+
|
|
24
|
+
### Key Features
|
|
25
|
+
|
|
26
|
+
| Feature | Description |
|
|
27
|
+
|---------|-------------|
|
|
28
|
+
| **Molecular Encoder** | GINEConv with edge-aware message passing and chirality features |
|
|
29
|
+
| **Target Encoder** | Multi-head GATConv for residue-level protein graphs |
|
|
30
|
+
| **Cross-Attention Fusion** | Atom-residue interaction for drug-target affinity |
|
|
31
|
+
| **PAGE Prototypes** | Case-based classification via learned prototypes |
|
|
32
|
+
| **MAGE Motifs** | Differentiable motif mask generation with Gumbel-sigmoid |
|
|
33
|
+
| **Concept Whitening** | ZCA whitening + axis-aligned concept interpretability |
|
|
34
|
+
| **CF-GNNExplainer** | Counterfactual minimal perturbation explanations |
|
|
35
|
+
| **T-GNNExplainer** | Sufficient subgraph identification |
|
|
36
|
+
| **CIDER Diagnostics** | Causal invariance testing across environments |
|
|
37
|
+
|
|
38
|
+
---
|
|
39
|
+
|
|
40
|
+
## Installation
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
# Clone the repository
|
|
44
|
+
git clone https://github.com/your-org/Inter_gnn.git
|
|
45
|
+
cd Inter_gnn
|
|
46
|
+
|
|
47
|
+
# Install with all dependencies
|
|
48
|
+
pip install -e ".[viz,dev]"
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
### Requirements
|
|
52
|
+
|
|
53
|
+
- Python ≥ 3.9
|
|
54
|
+
- PyTorch ≥ 2.0
|
|
55
|
+
- PyTorch Geometric ≥ 2.4
|
|
56
|
+
- RDKit ≥ 2023.03
|
|
57
|
+
- NumPy, SciPy, Pandas, scikit-learn, matplotlib
|
|
58
|
+
|
|
59
|
+
---
|
|
60
|
+
|
|
61
|
+
## Quick Start
|
|
62
|
+
|
|
63
|
+
### 1. Create a Configuration
|
|
64
|
+
|
|
65
|
+
```yaml
|
|
66
|
+
# config.yaml
|
|
67
|
+
data:
|
|
68
|
+
dataset_name: tox21
|
|
69
|
+
split_method: scaffold
|
|
70
|
+
batch_size: 32
|
|
71
|
+
detect_cliffs: true
|
|
72
|
+
compute_concepts: true
|
|
73
|
+
|
|
74
|
+
model:
|
|
75
|
+
hidden_dim: 256
|
|
76
|
+
num_mol_layers: 4
|
|
77
|
+
task_type: classification
|
|
78
|
+
num_tasks: 12
|
|
79
|
+
|
|
80
|
+
interpretability:
|
|
81
|
+
use_prototypes: true
|
|
82
|
+
num_prototypes_per_class: 5
|
|
83
|
+
use_motifs: true
|
|
84
|
+
num_motifs: 8
|
|
85
|
+
use_concept_whitening: true
|
|
86
|
+
|
|
87
|
+
training:
|
|
88
|
+
pretrain_epochs: 50
|
|
89
|
+
finetune_epochs: 100
|
|
90
|
+
learning_rate: 0.001
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
### 2. Train
|
|
94
|
+
|
|
95
|
+
```bash
|
|
96
|
+
inter-gnn train --config config.yaml
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
### 3. Evaluate
|
|
100
|
+
|
|
101
|
+
```bash
|
|
102
|
+
inter-gnn evaluate --config config.yaml --checkpoint checkpoints/finetune_best.pt
|
|
103
|
+
```
|
|
104
|
+
|
|
105
|
+
### 4. Generate Explanations
|
|
106
|
+
|
|
107
|
+
```bash
|
|
108
|
+
inter-gnn explain --config config.yaml --checkpoint model.pt --smiles "CC(=O)Oc1ccccc1C(=O)O"
|
|
109
|
+
```
|
|
110
|
+
|
|
111
|
+
### 5. Dashboard
|
|
112
|
+
|
|
113
|
+
```bash
|
|
114
|
+
inter-gnn dashboard --config config.yaml --checkpoint model.pt --output report/
|
|
115
|
+
```
|
|
116
|
+
|
|
117
|
+
---
|
|
118
|
+
|
|
119
|
+
## Python API
|
|
120
|
+
|
|
121
|
+
```python
|
|
122
|
+
from inter_gnn.training.config import InterGNNConfig
|
|
123
|
+
from inter_gnn.training.trainer import InterGNNTrainer
|
|
124
|
+
from inter_gnn.data.datamodule import InterGNNDataModule
|
|
125
|
+
|
|
126
|
+
# Load config
|
|
127
|
+
config = InterGNNConfig.from_yaml("config.yaml")
|
|
128
|
+
|
|
129
|
+
# Build data
|
|
130
|
+
dm = InterGNNDataModule(config)
|
|
131
|
+
dm.prepare_data()
|
|
132
|
+
dm.setup()
|
|
133
|
+
|
|
134
|
+
# Train (two-phase: pretrain → finetune)
|
|
135
|
+
trainer = InterGNNTrainer(config)
|
|
136
|
+
history = trainer.fit(dm.train_dataloader(), dm.val_dataloader())
|
|
137
|
+
|
|
138
|
+
# Explain a molecule
|
|
139
|
+
from inter_gnn.data.featurize import smiles_to_graph
|
|
140
|
+
import torch
|
|
141
|
+
|
|
142
|
+
graph = smiles_to_graph("CC(=O)Oc1ccccc1C(=O)O")
|
|
143
|
+
batch = torch.zeros(graph.x.shape[0], dtype=torch.long)
|
|
144
|
+
output = trainer.model(graph.x, graph.edge_index, graph.edge_attr, batch)
|
|
145
|
+
|
|
146
|
+
importance = trainer.model.get_node_importance(
|
|
147
|
+
graph.x, graph.edge_index, graph.edge_attr, batch
|
|
148
|
+
)
|
|
149
|
+
```
|
|
150
|
+
|
|
151
|
+
---
|
|
152
|
+
|
|
153
|
+
## Module Overview
|
|
154
|
+
|
|
155
|
+
```
|
|
156
|
+
inter_gnn/
|
|
157
|
+
├── data/ # Data & Preprocessing
|
|
158
|
+
│ ├── standardize.py # Molecule standardization (tautomer, charge, stereo)
|
|
159
|
+
│ ├── featurize.py # SMILES → molecular graph (~78-dim atom, ~14-dim bond)
|
|
160
|
+
│ ├── protein.py # Protein sequence → k-NN / contact graph
|
|
161
|
+
│ ├── concepts.py # SMARTS concept library (~30 patterns)
|
|
162
|
+
│ ├── cliffs.py # Activity cliff detection
|
|
163
|
+
│ ├── splits.py # Scaffold, cold-target, temporal splits
|
|
164
|
+
│ ├── datasets.py # 9 benchmark dataset loaders
|
|
165
|
+
│ └── datamodule.py # DataModule wrapper
|
|
166
|
+
├── models/ # Core Model
|
|
167
|
+
│ ├── encoders.py # GINEConv (molecule) + GATConv (protein) encoders
|
|
168
|
+
│ ├── attention.py # Cross-attention fusion + bilinear alternative
|
|
169
|
+
│ ├── task_heads.py # Classification + regression heads
|
|
170
|
+
│ └── core_model.py # Unified InterGNN model
|
|
171
|
+
├── interpretability/ # Intrinsic Interpretability
|
|
172
|
+
│ ├── prototypes.py # PAGE-inspired prototype layer
|
|
173
|
+
│ ├── motifs.py # MAGE-inspired motif generator
|
|
174
|
+
│ ├── concept_whitening.py # ZCA whitening + concept alignment
|
|
175
|
+
│ └── stability.py # Explanation stability regularizer
|
|
176
|
+
├── explainers/ # Post-hoc Explanations
|
|
177
|
+
│ ├── cf_explainer.py # CF-GNNExplainer (counterfactual)
|
|
178
|
+
│ ├── t_explainer.py # T-GNNExplainer (sufficient subgraph)
|
|
179
|
+
│ └── cider.py # CIDER causal invariance diagnostics
|
|
180
|
+
├── training/ # Training Pipeline
|
|
181
|
+
│ ├── losses.py # Combined multi-objective loss
|
|
182
|
+
│ ├── trainer.py # Two-phase trainer (pretrain + finetune)
|
|
183
|
+
│ ├── callbacks.py # EarlyStopping, checkpointing, monitoring
|
|
184
|
+
│ └── config.py # YAML config with dataclass hierarchy
|
|
185
|
+
├── evaluation/ # Evaluation Metrics
|
|
186
|
+
│ ├── predictive.py # ROC-AUC, PR-AUC, RMSE, CI, etc.
|
|
187
|
+
│ ├── faithfulness.py # Deletion/Insertion AUC, sufficiency/necessity
|
|
188
|
+
│ ├── stability_metrics.py # Jaccard stability, cliff consistency
|
|
189
|
+
│ ├── chemical_validity.py # Valence checks, SMARTS match rates
|
|
190
|
+
│ ├── causal.py # Invariance violation, environment alignment
|
|
191
|
+
│ └── statistical.py # Paired bootstrap, randomization tests
|
|
192
|
+
├── visualization/ # Visualization Tools
|
|
193
|
+
│ ├── molecule_viz.py # Atom/bond saliency rendering
|
|
194
|
+
│ ├── prototype_viz.py # Prototype gallery
|
|
195
|
+
│ ├── motif_viz.py # Motif activation heatmaps
|
|
196
|
+
│ ├── concept_viz.py # Concept activation bars
|
|
197
|
+
│ ├── counterfactual_viz.py# Counterfactual edit display
|
|
198
|
+
│ └── dashboard.py # HTML batch-export dashboard
|
|
199
|
+
└── cli.py # Command-line interface
|
|
200
|
+
```
|
|
201
|
+
|
|
202
|
+
---
|
|
203
|
+
|
|
204
|
+
## Supported Datasets
|
|
205
|
+
|
|
206
|
+
| Dataset | Type | Tasks | Source |
|
|
207
|
+
|---------|------|-------|--------|
|
|
208
|
+
| MUTAG | Classification | 1 | TUDataset |
|
|
209
|
+
| Tox21 | Classification | 12 | MoleculeNet |
|
|
210
|
+
| ClinTox | Classification | 2 | MoleculeNet |
|
|
211
|
+
| QM9 | Regression | 19 | MoleculeNet |
|
|
212
|
+
| Davis | DTA Regression | 1 | TDC |
|
|
213
|
+
| KIBA | DTA Regression | 1 | TDC |
|
|
214
|
+
| BindingDB | DTA Regression | 1 | TDC |
|
|
215
|
+
| SIDER | Classification | 27 | MoleculeNet |
|
|
216
|
+
| SynLethDB | Classification | 1 | Custom |
|
|
217
|
+
|
|
218
|
+
---
|
|
219
|
+
|
|
220
|
+
## Two-Phase Training
|
|
221
|
+
|
|
222
|
+
1. **Pre-training** — Trains encoders + task head with prediction loss only
|
|
223
|
+
2. **Joint Fine-tuning** — Attaches interpretability modules, trains all losses:
|
|
224
|
+
- `L_pred`: Task prediction (BCE/MSE)
|
|
225
|
+
- `L_pull/push/div`: Prototype losses
|
|
226
|
+
- `L_sparsity/conn`: Motif losses
|
|
227
|
+
- `L_align/decorr`: Concept whitening losses
|
|
228
|
+
- `L_stability`: Explanation stability
|
|
229
|
+
|
|
230
|
+
---
|
|
231
|
+
|
|
232
|
+
## Citation
|
|
233
|
+
|
|
234
|
+
```bibtex
|
|
235
|
+
@software{inter_gnn2025,
|
|
236
|
+
title={InterGNN: Interpretable Graph Neural Network for Drug Discovery},
|
|
237
|
+
year={2025},
|
|
238
|
+
}
|
|
239
|
+
```
|
|
240
|
+
|
|
241
|
+
## License
|
|
242
|
+
|
|
243
|
+
MIT License
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Inter-GNN: Interpretable GNN-Based Framework for Drug Discovery and Candidate Screening.
|
|
3
|
+
|
|
4
|
+
A modular Python package implementing:
|
|
5
|
+
- Data preprocessing (standardization, graph featurization, protein graphs, concepts, cliffs, splits)
|
|
6
|
+
- Core GNN model (edge/chirality-aware MPNN, cross-attention fusion, task heads)
|
|
7
|
+
- Intrinsic interpretability (PAGE prototypes, MAGE motifs, concept whitening)
|
|
8
|
+
- Post-hoc explainability (CF-GNNExplainer, T-GNNExplainer, CIDER diagnostics)
|
|
9
|
+
- Evaluation metrics (predictive, faithfulness, stability, chemical validity, causal)
|
|
10
|
+
- Visualization tools (saliency, prototypes, motifs, concepts, counterfactuals)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
__version__ = "0.1.0"
|
|
14
|
+
__author__ = "Harshal Loya, Jash Chauhan, Het Gala"
|
|
15
|
+
|
|
16
|
+
from inter_gnn.models.core_model import InterGNN
|
|
17
|
+
from inter_gnn.training.config import InterGNNConfig
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
"InterGNN",
|
|
21
|
+
"InterGNNConfig",
|
|
22
|
+
"__version__",
|
|
23
|
+
]
|