inter-gnn 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. inter_gnn-0.1.0/PKG-INFO +288 -0
  2. inter_gnn-0.1.0/README.md +243 -0
  3. inter_gnn-0.1.0/inter_gnn/__init__.py +23 -0
  4. inter_gnn-0.1.0/inter_gnn/cli.py +248 -0
  5. inter_gnn-0.1.0/inter_gnn/data/__init__.py +27 -0
  6. inter_gnn-0.1.0/inter_gnn/data/cliffs.py +283 -0
  7. inter_gnn-0.1.0/inter_gnn/data/concepts.py +369 -0
  8. inter_gnn-0.1.0/inter_gnn/data/datamodule.py +143 -0
  9. inter_gnn-0.1.0/inter_gnn/data/datasets.py +202 -0
  10. inter_gnn-0.1.0/inter_gnn/data/featurize.py +384 -0
  11. inter_gnn-0.1.0/inter_gnn/data/protein.py +410 -0
  12. inter_gnn-0.1.0/inter_gnn/data/splits.py +373 -0
  13. inter_gnn-0.1.0/inter_gnn/data/standardize.py +288 -0
  14. inter_gnn-0.1.0/inter_gnn/evaluation/__init__.py +24 -0
  15. inter_gnn-0.1.0/inter_gnn/evaluation/causal.py +104 -0
  16. inter_gnn-0.1.0/inter_gnn/evaluation/chemical_validity.py +134 -0
  17. inter_gnn-0.1.0/inter_gnn/evaluation/faithfulness.py +176 -0
  18. inter_gnn-0.1.0/inter_gnn/evaluation/predictive.py +145 -0
  19. inter_gnn-0.1.0/inter_gnn/evaluation/stability_metrics.py +115 -0
  20. inter_gnn-0.1.0/inter_gnn/evaluation/statistical.py +119 -0
  21. inter_gnn-0.1.0/inter_gnn/explainers/__init__.py +11 -0
  22. inter_gnn-0.1.0/inter_gnn/explainers/cf_explainer.py +201 -0
  23. inter_gnn-0.1.0/inter_gnn/explainers/cider.py +207 -0
  24. inter_gnn-0.1.0/inter_gnn/explainers/t_explainer.py +176 -0
  25. inter_gnn-0.1.0/inter_gnn/interpretability/__init__.py +14 -0
  26. inter_gnn-0.1.0/inter_gnn/interpretability/concept_whitening.py +171 -0
  27. inter_gnn-0.1.0/inter_gnn/interpretability/motifs.py +198 -0
  28. inter_gnn-0.1.0/inter_gnn/interpretability/prototypes.py +195 -0
  29. inter_gnn-0.1.0/inter_gnn/interpretability/stability.py +183 -0
  30. inter_gnn-0.1.0/inter_gnn/models/__init__.py +17 -0
  31. inter_gnn-0.1.0/inter_gnn/models/attention.py +225 -0
  32. inter_gnn-0.1.0/inter_gnn/models/core_model.py +221 -0
  33. inter_gnn-0.1.0/inter_gnn/models/encoders.py +244 -0
  34. inter_gnn-0.1.0/inter_gnn/models/task_heads.py +150 -0
  35. inter_gnn-0.1.0/inter_gnn/training/__init__.py +11 -0
  36. inter_gnn-0.1.0/inter_gnn/training/callbacks.py +161 -0
  37. inter_gnn-0.1.0/inter_gnn/training/config.py +178 -0
  38. inter_gnn-0.1.0/inter_gnn/training/losses.py +179 -0
  39. inter_gnn-0.1.0/inter_gnn/training/trainer.py +308 -0
  40. inter_gnn-0.1.0/inter_gnn/visualization/__init__.py +15 -0
  41. inter_gnn-0.1.0/inter_gnn/visualization/concept_viz.py +120 -0
  42. inter_gnn-0.1.0/inter_gnn/visualization/counterfactual_viz.py +148 -0
  43. inter_gnn-0.1.0/inter_gnn/visualization/dashboard.py +204 -0
  44. inter_gnn-0.1.0/inter_gnn/visualization/molecule_viz.py +185 -0
  45. inter_gnn-0.1.0/inter_gnn/visualization/motif_viz.py +133 -0
  46. inter_gnn-0.1.0/inter_gnn/visualization/prototype_viz.py +132 -0
  47. inter_gnn-0.1.0/inter_gnn.egg-info/PKG-INFO +288 -0
  48. inter_gnn-0.1.0/inter_gnn.egg-info/SOURCES.txt +52 -0
  49. inter_gnn-0.1.0/inter_gnn.egg-info/dependency_links.txt +1 -0
  50. inter_gnn-0.1.0/inter_gnn.egg-info/entry_points.txt +2 -0
  51. inter_gnn-0.1.0/inter_gnn.egg-info/requires.txt +23 -0
  52. inter_gnn-0.1.0/inter_gnn.egg-info/top_level.txt +1 -0
  53. inter_gnn-0.1.0/pyproject.toml +89 -0
  54. inter_gnn-0.1.0/setup.cfg +4 -0
@@ -0,0 +1,288 @@
1
+ Metadata-Version: 2.4
2
+ Name: inter-gnn
3
+ Version: 0.1.0
4
+ Summary: Interpretable GNN-Based Framework for Drug Discovery and Candidate Screening
5
+ Author: Harshal Loya, Jash Chauhan, Het Gala
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/inter-gnn/inter-gnn
8
+ Project-URL: Documentation, https://inter-gnn.readthedocs.io
9
+ Project-URL: Repository, https://github.com/inter-gnn/inter-gnn
10
+ Keywords: graph-neural-networks,drug-discovery,explainable-ai,molecular-property-prediction,interpretability,activity-cliffs,concept-whitening
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.9
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
21
+ Classifier: Topic :: Scientific/Engineering :: Chemistry
22
+ Requires-Python: >=3.9
23
+ Description-Content-Type: text/markdown
24
+ Requires-Dist: torch>=2.0.0
25
+ Requires-Dist: torch-geometric>=2.4.0
26
+ Requires-Dist: rdkit>=2023.3.1
27
+ Requires-Dist: numpy>=1.24.0
28
+ Requires-Dist: scipy>=1.10.0
29
+ Requires-Dist: pandas>=2.0.0
30
+ Requires-Dist: scikit-learn>=1.2.0
31
+ Requires-Dist: matplotlib>=3.7.0
32
+ Requires-Dist: pyyaml>=6.0
33
+ Requires-Dist: tqdm>=4.65.0
34
+ Provides-Extra: viz
35
+ Requires-Dist: plotly>=5.14.0; extra == "viz"
36
+ Requires-Dist: py3Dmol>=2.0.0; extra == "viz"
37
+ Requires-Dist: seaborn>=0.12.0; extra == "viz"
38
+ Requires-Dist: ipywidgets>=8.0.0; extra == "viz"
39
+ Provides-Extra: dev
40
+ Requires-Dist: pytest>=7.3.0; extra == "dev"
41
+ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
42
+ Requires-Dist: black>=23.0.0; extra == "dev"
43
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
44
+ Requires-Dist: mypy>=1.4.0; extra == "dev"
45
+
46
+ # InterGNN — Interpretable GNN-Based Framework for Drug Discovery
47
+
48
+ ![Python 3.9+](https://img.shields.io/badge/python-3.9%2B-blue)
49
+ ![PyTorch](https://img.shields.io/badge/PyTorch-2.0%2B-red)
50
+ ![License: MIT](https://img.shields.io/badge/License-MIT-green)
51
+
52
+ An interpretable Graph Neural Network framework combining state-of-the-art molecular property prediction with inherent and post-hoc explainability methods. Designed for drug discovery workflows requiring trust, transparency, and scientific insight.
53
+
54
+ ---
55
+
56
+ ## Architecture
57
+
58
+ ```
59
+ SMILES → Standardize → Featurize → MolecularGNNEncoder ──┐
60
+ ├─ CrossAttention → TaskHead → Prediction
61
+ Protein → ProteinGraphBuilder → TargetGNNEncoder ─────────┘
62
+
63
+ ┌───────────────┼───────────────┐
64
+ ▼ ▼ ▼
65
+ PrototypeLayer MotifHead ConceptWhitening
66
+ (case-based) (substructure) (axis-aligned)
67
+ ```
68
+
69
+ ### Key Features
70
+
71
+ | Feature | Description |
72
+ |---------|-------------|
73
+ | **Molecular Encoder** | GINEConv with edge-aware message passing and chirality features |
74
+ | **Target Encoder** | Multi-head GATConv for residue-level protein graphs |
75
+ | **Cross-Attention Fusion** | Atom-residue interaction for drug-target affinity |
76
+ | **PAGE Prototypes** | Case-based classification via learned prototypes |
77
+ | **MAGE Motifs** | Differentiable motif mask generation with Gumbel-sigmoid |
78
+ | **Concept Whitening** | ZCA whitening + axis-aligned concept interpretability |
79
+ | **CF-GNNExplainer** | Counterfactual minimal perturbation explanations |
80
+ | **T-GNNExplainer** | Sufficient subgraph identification |
81
+ | **CIDER Diagnostics** | Causal invariance testing across environments |
82
+
83
+ ---
84
+
85
+ ## Installation
86
+
87
+ ```bash
88
+ # Clone the repository
89
+ git clone https://github.com/inter-gnn/inter-gnn.git
90
+ cd inter-gnn
91
+
92
+ # Install with all dependencies
93
+ pip install -e ".[viz,dev]"
94
+ ```
95
+
96
+ ### Requirements
97
+
98
+ - Python ≥ 3.9
99
+ - PyTorch ≥ 2.0
100
+ - PyTorch Geometric ≥ 2.4
101
+ - RDKit ≥ 2023.03
102
+ - NumPy, SciPy, Pandas, scikit-learn, matplotlib, PyYAML, tqdm
103
+
104
+ ---
105
+
106
+ ## Quick Start
107
+
108
+ ### 1. Create a Configuration
109
+
110
+ ```yaml
111
+ # config.yaml
112
+ data:
113
+   dataset_name: tox21
114
+   split_method: scaffold
115
+   batch_size: 32
116
+   detect_cliffs: true
117
+   compute_concepts: true
118
+
119
+ model:
120
+   hidden_dim: 256
121
+   num_mol_layers: 4
122
+   task_type: classification
123
+   num_tasks: 12
124
+
125
+ interpretability:
126
+   use_prototypes: true
127
+   num_prototypes_per_class: 5
128
+   use_motifs: true
129
+   num_motifs: 8
130
+   use_concept_whitening: true
131
+
132
+ training:
133
+   pretrain_epochs: 50
134
+   finetune_epochs: 100
135
+   learning_rate: 0.001
136
+ ```
137
+
138
+ ### 2. Train
139
+
140
+ ```bash
141
+ inter-gnn train --config config.yaml
142
+ ```
143
+
144
+ ### 3. Evaluate
145
+
146
+ ```bash
147
+ inter-gnn evaluate --config config.yaml --checkpoint checkpoints/finetune_best.pt
148
+ ```
149
+
150
+ ### 4. Generate Explanations
151
+
152
+ ```bash
153
+ inter-gnn explain --config config.yaml --checkpoint checkpoints/finetune_best.pt --smiles "CC(=O)Oc1ccccc1C(=O)O"
154
+ ```
155
+
156
+ ### 5. Dashboard
157
+
158
+ ```bash
159
+ inter-gnn dashboard --config config.yaml --checkpoint checkpoints/finetune_best.pt --output report/
160
+ ```
161
+
162
+ ---
163
+
164
+ ## Python API
165
+
166
+ ```python
167
+ from inter_gnn.training.config import InterGNNConfig
168
+ from inter_gnn.training.trainer import InterGNNTrainer
169
+ from inter_gnn.data.datamodule import InterGNNDataModule
170
+
171
+ # Load config
172
+ config = InterGNNConfig.from_yaml("config.yaml")
173
+
174
+ # Build data
175
+ dm = InterGNNDataModule(config)
176
+ dm.prepare_data()
177
+ dm.setup()
178
+
179
+ # Train (two-phase: pretrain → finetune)
180
+ trainer = InterGNNTrainer(config)
181
+ history = trainer.fit(dm.train_dataloader(), dm.val_dataloader())
182
+
183
+ # Explain a molecule
184
+ from inter_gnn.data.featurize import smiles_to_graph
185
+ import torch
186
+
187
+ graph = smiles_to_graph("CC(=O)Oc1ccccc1C(=O)O")
188
+ batch = torch.zeros(graph.x.shape[0], dtype=torch.long)
189
+ output = trainer.model(graph.x, graph.edge_index, graph.edge_attr, batch)
190
+
191
+ importance = trainer.model.get_node_importance(
192
+ graph.x, graph.edge_index, graph.edge_attr, batch
193
+ )
194
+ ```
195
+
196
+ ---
197
+
198
+ ## Module Overview
199
+
200
+ ```
201
+ inter_gnn/
202
+ ├── data/ # Data & Preprocessing
203
+ │ ├── standardize.py # Molecule standardization (tautomer, charge, stereo)
204
+ │ ├── featurize.py # SMILES → molecular graph (~78-dim atom, ~14-dim bond)
205
+ │ ├── protein.py # Protein sequence → k-NN / contact graph
206
+ │ ├── concepts.py # SMARTS concept library (~30 patterns)
207
+ │ ├── cliffs.py # Activity cliff detection
208
+ │ ├── splits.py # Scaffold, cold-target, temporal splits
209
+ │ ├── datasets.py # 9 benchmark dataset loaders
210
+ │ └── datamodule.py # DataModule wrapper
211
+ ├── models/ # Core Model
212
+ │ ├── encoders.py # GINEConv (molecule) + GATConv (protein) encoders
213
+ │ ├── attention.py # Cross-attention fusion + bilinear alternative
214
+ │ ├── task_heads.py # Classification + regression heads
215
+ │ └── core_model.py # Unified InterGNN model
216
+ ├── interpretability/ # Intrinsic Interpretability
217
+ │ ├── prototypes.py # PAGE-inspired prototype layer
218
+ │ ├── motifs.py # MAGE-inspired motif generator
219
+ │ ├── concept_whitening.py # ZCA whitening + concept alignment
220
+ │ └── stability.py # Explanation stability regularizer
221
+ ├── explainers/ # Post-hoc Explanations
222
+ │ ├── cf_explainer.py # CF-GNNExplainer (counterfactual)
223
+ │ ├── t_explainer.py # T-GNNExplainer (sufficient subgraph)
224
+ │ └── cider.py # CIDER causal invariance diagnostics
225
+ ├── training/ # Training Pipeline
226
+ │ ├── losses.py # Combined multi-objective loss
227
+ │ ├── trainer.py # Two-phase trainer (pretrain + finetune)
228
+ │ ├── callbacks.py # EarlyStopping, checkpointing, monitoring
229
+ │ └── config.py # YAML config with dataclass hierarchy
230
+ ├── evaluation/ # Evaluation Metrics
231
+ │ ├── predictive.py # ROC-AUC, PR-AUC, RMSE, CI, etc.
232
+ │ ├── faithfulness.py # Deletion/Insertion AUC, sufficiency/necessity
233
+ │ ├── stability_metrics.py # Jaccard stability, cliff consistency
234
+ │ ├── chemical_validity.py # Valence checks, SMARTS match rates
235
+ │ ├── causal.py # Invariance violation, environment alignment
236
+ │ └── statistical.py # Paired bootstrap, randomization tests
237
+ ├── visualization/ # Visualization Tools
238
+ │ ├── molecule_viz.py # Atom/bond saliency rendering
239
+ │ ├── prototype_viz.py # Prototype gallery
240
+ │ ├── motif_viz.py # Motif activation heatmaps
241
+ │ ├── concept_viz.py # Concept activation bars
242
+ │ ├── counterfactual_viz.py# Counterfactual edit display
243
+ │ └── dashboard.py # HTML batch-export dashboard
244
+ └── cli.py # Command-line interface
245
+ ```
246
+
247
+ ---
248
+
249
+ ## Supported Datasets
250
+
251
+ | Dataset | Type | Tasks | Source |
252
+ |---------|------|-------|--------|
253
+ | MUTAG | Classification | 1 | TUDataset |
254
+ | Tox21 | Classification | 12 | MoleculeNet |
255
+ | ClinTox | Classification | 2 | MoleculeNet |
256
+ | QM9 | Regression | 19 | MoleculeNet |
257
+ | Davis | DTA Regression | 1 | TDC |
258
+ | KIBA | DTA Regression | 1 | TDC |
259
+ | BindingDB | DTA Regression | 1 | TDC |
260
+ | SIDER | Classification | 27 | MoleculeNet |
261
+ | SynLethDB | Classification | 1 | Custom |
262
+
263
+ ---
264
+
265
+ ## Two-Phase Training
266
+
267
+ 1. **Pre-training** — Trains encoders + task head with prediction loss only
268
+ 2. **Joint Fine-tuning** — Attaches interpretability modules, trains all losses:
269
+ - `L_pred`: Task prediction (BCE/MSE)
270
+ - `L_pull/push/div`: Prototype losses
271
+ - `L_sparsity/conn`: Motif losses
272
+ - `L_align/decorr`: Concept whitening losses
273
+ - `L_stability`: Explanation stability
274
+
275
+ ---
276
+
277
+ ## Citation
278
+
279
+ ```bibtex
280
+ @software{inter_gnn2025,
281
+ title={InterGNN: Interpretable Graph Neural Network for Drug Discovery},
282
+ year={2025},
283
+ }
284
+ ```
285
+
286
+ ## License
287
+
288
+ MIT License
@@ -0,0 +1,243 @@
1
+ # InterGNN — Interpretable GNN-Based Framework for Drug Discovery
2
+
3
+ ![Python 3.9+](https://img.shields.io/badge/python-3.9%2B-blue)
4
+ ![PyTorch](https://img.shields.io/badge/PyTorch-2.0%2B-red)
5
+ ![License: MIT](https://img.shields.io/badge/License-MIT-green)
6
+
7
+ An interpretable Graph Neural Network framework combining state-of-the-art molecular property prediction with inherent and post-hoc explainability methods. Designed for drug discovery workflows requiring trust, transparency, and scientific insight.
8
+
9
+ ---
10
+
11
+ ## Architecture
12
+
13
+ ```
14
+ SMILES → Standardize → Featurize → MolecularGNNEncoder ──┐
15
+ ├─ CrossAttention → TaskHead → Prediction
16
+ Protein → ProteinGraphBuilder → TargetGNNEncoder ─────────┘
17
+
18
+ ┌───────────────┼───────────────┐
19
+ ▼ ▼ ▼
20
+ PrototypeLayer MotifHead ConceptWhitening
21
+ (case-based) (substructure) (axis-aligned)
22
+ ```
23
+
24
+ ### Key Features
25
+
26
+ | Feature | Description |
27
+ |---------|-------------|
28
+ | **Molecular Encoder** | GINEConv with edge-aware message passing and chirality features |
29
+ | **Target Encoder** | Multi-head GATConv for residue-level protein graphs |
30
+ | **Cross-Attention Fusion** | Atom-residue interaction for drug-target affinity |
31
+ | **PAGE Prototypes** | Case-based classification via learned prototypes |
32
+ | **MAGE Motifs** | Differentiable motif mask generation with Gumbel-sigmoid |
33
+ | **Concept Whitening** | ZCA whitening + axis-aligned concept interpretability |
34
+ | **CF-GNNExplainer** | Counterfactual minimal perturbation explanations |
35
+ | **T-GNNExplainer** | Sufficient subgraph identification |
36
+ | **CIDER Diagnostics** | Causal invariance testing across environments |
37
+
38
+ ---
39
+
40
+ ## Installation
41
+
42
+ ```bash
43
+ # Clone the repository
44
+ git clone https://github.com/inter-gnn/inter-gnn.git
45
+ cd inter-gnn
46
+
47
+ # Install with all dependencies
48
+ pip install -e ".[viz,dev]"
49
+ ```
50
+
51
+ ### Requirements
52
+
53
+ - Python ≥ 3.9
54
+ - PyTorch ≥ 2.0
55
+ - PyTorch Geometric ≥ 2.4
56
+ - RDKit ≥ 2023.03
57
+ - NumPy, SciPy, Pandas, scikit-learn, matplotlib, PyYAML, tqdm
58
+
59
+ ---
60
+
61
+ ## Quick Start
62
+
63
+ ### 1. Create a Configuration
64
+
65
+ ```yaml
66
+ # config.yaml
67
+ data:
68
+   dataset_name: tox21
69
+   split_method: scaffold
70
+   batch_size: 32
71
+   detect_cliffs: true
72
+   compute_concepts: true
73
+
74
+ model:
75
+   hidden_dim: 256
76
+   num_mol_layers: 4
77
+   task_type: classification
78
+   num_tasks: 12
79
+
80
+ interpretability:
81
+   use_prototypes: true
82
+   num_prototypes_per_class: 5
83
+   use_motifs: true
84
+   num_motifs: 8
85
+   use_concept_whitening: true
86
+
87
+ training:
88
+   pretrain_epochs: 50
89
+   finetune_epochs: 100
90
+   learning_rate: 0.001
91
+ ```
92
+
93
+ ### 2. Train
94
+
95
+ ```bash
96
+ inter-gnn train --config config.yaml
97
+ ```
98
+
99
+ ### 3. Evaluate
100
+
101
+ ```bash
102
+ inter-gnn evaluate --config config.yaml --checkpoint checkpoints/finetune_best.pt
103
+ ```
104
+
105
+ ### 4. Generate Explanations
106
+
107
+ ```bash
108
+ inter-gnn explain --config config.yaml --checkpoint checkpoints/finetune_best.pt --smiles "CC(=O)Oc1ccccc1C(=O)O"
109
+ ```
110
+
111
+ ### 5. Dashboard
112
+
113
+ ```bash
114
+ inter-gnn dashboard --config config.yaml --checkpoint checkpoints/finetune_best.pt --output report/
115
+ ```
116
+
117
+ ---
118
+
119
+ ## Python API
120
+
121
+ ```python
122
+ from inter_gnn.training.config import InterGNNConfig
123
+ from inter_gnn.training.trainer import InterGNNTrainer
124
+ from inter_gnn.data.datamodule import InterGNNDataModule
125
+
126
+ # Load config
127
+ config = InterGNNConfig.from_yaml("config.yaml")
128
+
129
+ # Build data
130
+ dm = InterGNNDataModule(config)
131
+ dm.prepare_data()
132
+ dm.setup()
133
+
134
+ # Train (two-phase: pretrain → finetune)
135
+ trainer = InterGNNTrainer(config)
136
+ history = trainer.fit(dm.train_dataloader(), dm.val_dataloader())
137
+
138
+ # Explain a molecule
139
+ from inter_gnn.data.featurize import smiles_to_graph
140
+ import torch
141
+
142
+ graph = smiles_to_graph("CC(=O)Oc1ccccc1C(=O)O")
143
+ batch = torch.zeros(graph.x.shape[0], dtype=torch.long)
144
+ output = trainer.model(graph.x, graph.edge_index, graph.edge_attr, batch)
145
+
146
+ importance = trainer.model.get_node_importance(
147
+ graph.x, graph.edge_index, graph.edge_attr, batch
148
+ )
149
+ ```
150
+
151
+ ---
152
+
153
+ ## Module Overview
154
+
155
+ ```
156
+ inter_gnn/
157
+ ├── data/ # Data & Preprocessing
158
+ │ ├── standardize.py # Molecule standardization (tautomer, charge, stereo)
159
+ │ ├── featurize.py # SMILES → molecular graph (~78-dim atom, ~14-dim bond)
160
+ │ ├── protein.py # Protein sequence → k-NN / contact graph
161
+ │ ├── concepts.py # SMARTS concept library (~30 patterns)
162
+ │ ├── cliffs.py # Activity cliff detection
163
+ │ ├── splits.py # Scaffold, cold-target, temporal splits
164
+ │ ├── datasets.py # 9 benchmark dataset loaders
165
+ │ └── datamodule.py # DataModule wrapper
166
+ ├── models/ # Core Model
167
+ │ ├── encoders.py # GINEConv (molecule) + GATConv (protein) encoders
168
+ │ ├── attention.py # Cross-attention fusion + bilinear alternative
169
+ │ ├── task_heads.py # Classification + regression heads
170
+ │ └── core_model.py # Unified InterGNN model
171
+ ├── interpretability/ # Intrinsic Interpretability
172
+ │ ├── prototypes.py # PAGE-inspired prototype layer
173
+ │ ├── motifs.py # MAGE-inspired motif generator
174
+ │ ├── concept_whitening.py # ZCA whitening + concept alignment
175
+ │ └── stability.py # Explanation stability regularizer
176
+ ├── explainers/ # Post-hoc Explanations
177
+ │ ├── cf_explainer.py # CF-GNNExplainer (counterfactual)
178
+ │ ├── t_explainer.py # T-GNNExplainer (sufficient subgraph)
179
+ │ └── cider.py # CIDER causal invariance diagnostics
180
+ ├── training/ # Training Pipeline
181
+ │ ├── losses.py # Combined multi-objective loss
182
+ │ ├── trainer.py # Two-phase trainer (pretrain + finetune)
183
+ │ ├── callbacks.py # EarlyStopping, checkpointing, monitoring
184
+ │ └── config.py # YAML config with dataclass hierarchy
185
+ ├── evaluation/ # Evaluation Metrics
186
+ │ ├── predictive.py # ROC-AUC, PR-AUC, RMSE, CI, etc.
187
+ │ ├── faithfulness.py # Deletion/Insertion AUC, sufficiency/necessity
188
+ │ ├── stability_metrics.py # Jaccard stability, cliff consistency
189
+ │ ├── chemical_validity.py # Valence checks, SMARTS match rates
190
+ │ ├── causal.py # Invariance violation, environment alignment
191
+ │ └── statistical.py # Paired bootstrap, randomization tests
192
+ ├── visualization/ # Visualization Tools
193
+ │ ├── molecule_viz.py # Atom/bond saliency rendering
194
+ │ ├── prototype_viz.py # Prototype gallery
195
+ │ ├── motif_viz.py # Motif activation heatmaps
196
+ │ ├── concept_viz.py # Concept activation bars
197
+ │ ├── counterfactual_viz.py# Counterfactual edit display
198
+ │ └── dashboard.py # HTML batch-export dashboard
199
+ └── cli.py # Command-line interface
200
+ ```
201
+
202
+ ---
203
+
204
+ ## Supported Datasets
205
+
206
+ | Dataset | Type | Tasks | Source |
207
+ |---------|------|-------|--------|
208
+ | MUTAG | Classification | 1 | TUDataset |
209
+ | Tox21 | Classification | 12 | MoleculeNet |
210
+ | ClinTox | Classification | 2 | MoleculeNet |
211
+ | QM9 | Regression | 19 | MoleculeNet |
212
+ | Davis | DTA Regression | 1 | TDC |
213
+ | KIBA | DTA Regression | 1 | TDC |
214
+ | BindingDB | DTA Regression | 1 | TDC |
215
+ | SIDER | Classification | 27 | MoleculeNet |
216
+ | SynLethDB | Classification | 1 | Custom |
217
+
218
+ ---
219
+
220
+ ## Two-Phase Training
221
+
222
+ 1. **Pre-training** — Trains encoders + task head with prediction loss only
223
+ 2. **Joint Fine-tuning** — Attaches interpretability modules, trains all losses:
224
+ - `L_pred`: Task prediction (BCE/MSE)
225
+ - `L_pull/push/div`: Prototype losses
226
+ - `L_sparsity/conn`: Motif losses
227
+ - `L_align/decorr`: Concept whitening losses
228
+ - `L_stability`: Explanation stability
229
+
230
+ ---
231
+
232
+ ## Citation
233
+
234
+ ```bibtex
235
+ @software{inter_gnn2025,
236
+ title={InterGNN: Interpretable Graph Neural Network for Drug Discovery},
237
+ year={2025},
238
+ }
239
+ ```
240
+
241
+ ## License
242
+
243
+ MIT License
@@ -0,0 +1,23 @@
1
+ """
2
+ Inter-GNN: Interpretable GNN-Based Framework for Drug Discovery and Candidate Screening.
3
+
4
+ A modular Python package implementing:
5
+ - Data preprocessing (standardization, graph featurization, protein graphs, concepts, cliffs, splits)
6
+ - Core GNN model (edge/chirality-aware MPNN, cross-attention fusion, task heads)
7
+ - Intrinsic interpretability (PAGE prototypes, MAGE motifs, concept whitening)
8
+ - Post-hoc explainability (CF-GNNExplainer, T-GNNExplainer, CIDER diagnostics)
9
+ - Evaluation metrics (predictive, faithfulness, stability, chemical validity, causal)
10
+ - Visualization tools (saliency, prototypes, motifs, concepts, counterfactuals)
11
+ """
12
+
13
+ __version__ = "0.1.0"
14
+ __author__ = "Harshal Loya, Jash Chauhan, Het Gala"
15
+
16
+ from inter_gnn.models.core_model import InterGNN
17
+ from inter_gnn.training.config import InterGNNConfig
18
+
19
+ __all__ = [
20
+ "InterGNN",
21
+ "InterGNNConfig",
22
+ "__version__",
23
+ ]