chebai-graph 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chebai_graph-1.0.0/PKG-INFO +14 -0
- chebai_graph-1.0.0/chebai_graph/__init__.py +0 -0
- chebai_graph-1.0.0/chebai_graph/loss/__init__.py +0 -0
- chebai_graph-1.0.0/chebai_graph/loss/pretraining.py +31 -0
- chebai_graph-1.0.0/chebai_graph/models/__init__.py +19 -0
- chebai_graph-1.0.0/chebai_graph/models/augmented.py +45 -0
- chebai_graph-1.0.0/chebai_graph/models/base.py +707 -0
- chebai_graph-1.0.0/chebai_graph/models/dynamic_gni.py +214 -0
- chebai_graph-1.0.0/chebai_graph/models/gat.py +94 -0
- chebai_graph-1.0.0/chebai_graph/models/gin_net.py +96 -0
- chebai_graph-1.0.0/chebai_graph/models/graph.py +373 -0
- chebai_graph-1.0.0/chebai_graph/models/resgated.py +109 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/__init__.py +0 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt +13 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomFunctionalGroup/indices_one_hot.txt +158 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomHybridization/indices_one_hot.txt +7 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomNodeLevel/indices_one_hot.txt +3 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt +7 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomType/indices_one_hot.txt +119 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt +4 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt +5 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt +11 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/collate.py +78 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/datasets/__init__.py +36 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/datasets/chebi.py +734 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/datasets/pubchem.py +7 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/datasets/utils.py +43 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/fg_detection/__init__.py +0 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +1936 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/fg_detection/fg_constants.py +14 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/properties/__init__.py +89 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/properties/augmented_properties.py +411 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/properties/base.py +455 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/properties/constants.py +13 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/properties/properties.py +299 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/property_encoder.py +292 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/reader/__init__.py +28 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/reader/augmented_reader.py +943 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/reader/reader.py +203 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/reader/static_gni.py +191 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/structures.py +42 -0
- chebai_graph-1.0.0/chebai_graph/preprocessing/transform_unlabeled.py +109 -0
- chebai_graph-1.0.0/pyproject.toml +34 -0
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: chebai-graph
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: GNNs for ChEB-AI
|
|
5
|
+
Author-email: Martin Glauer <martin.glauer@ovgu.de>
|
|
6
|
+
Requires-Python: >=3.8
|
|
7
|
+
Requires-Dist: chebai
|
|
8
|
+
Requires-Dist: descriptastorus
|
|
9
|
+
Requires-Dist: tox ; extra == "dev"
|
|
10
|
+
Requires-Dist: isort ; extra == "linters"
|
|
11
|
+
Requires-Dist: pre-commit ; extra == "linters"
|
|
12
|
+
Requires-Dist: black ; extra == "linters"
|
|
13
|
+
Provides-Extra: dev
|
|
14
|
+
Provides-Extra: linters
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class MaskPretrainingLoss(torch.nn.Module):
    """Loss for masked atom/bond pretraining.

    Mask atoms and edges, try to predict them (see Hu et al., 2020:
    Strategies for Pre-training Graph Neural Networks).

    Each term is binary cross-entropy on logits. The atom term is always
    computed; the bond term is only computed when the predictions are passed
    as an ``(atom_preds, bond_preds)`` tuple. The two terms are summed.
    """

    def __init__(self):
        super().__init__()
        # Despite the attribute name, this is binary cross-entropy on logits.
        self.ce = torch.nn.functional.binary_cross_entropy_with_logits

    def _loss_term(self, kind, preds, targets):
        """Return ``self.ce(preds, targets)``, or a zero tensor if it fails.

        Args:
            kind: Label used in the diagnostic output ("atom" or "bond").
            preds: Predicted logits for this term.
            targets: Labels for this term.

        Returns:
            The loss tensor. On ``RuntimeError`` the failure is printed and a
            scalar zero tensor derived from ``preds`` is returned instead.
        """
        try:
            return self.ce(preds, targets)
        except RuntimeError as e:
            # Best-effort: report the failing term and carry on without it.
            # NOTE(review): only RuntimeError is handled here; confirm whether
            # other exception types (e.g. a ValueError for mismatched sizes)
            # should be treated the same way.
            print(f"Failed to compute {kind} loss: {e}")
            print(f"Input: preds: {preds.shape}, labels: {targets.shape}")
            # A zero that stays attached to the autograd graph of `preds`, so
            # the value returned by forward() is always a tensor (never the
            # Python int 0) and can still be back-propagated when every term
            # failed.
            return preds.sum() * 0.0

    def forward(self, input, target, **loss_kwargs):
        """Compute the summed atom (and optionally bond) loss.

        Args:
            input: Either the atom logits, or a tuple
                ``(atom_preds, bond_preds)``.
            target: Labels with the same structure as ``input``: the atom
                labels, or a pair ``(atom_targets, bond_targets)``.
            **loss_kwargs: Accepted for interface compatibility; not used.

        Returns:
            A tensor holding ``atom_loss + bond_loss``. A term that raised
            ``RuntimeError`` contributes zero.
        """
        if isinstance(input, tuple):
            atom_preds, bond_preds = input
            atom_targets, bond_targets = target
            # The bond term is evaluated first so diagnostics keep their order.
            bond_loss = self._loss_term("bond", bond_preds, bond_targets)
        else:
            atom_preds, atom_targets = input, target
            bond_loss = 0

        atom_loss = self._loss_term("atom", atom_preds, atom_targets)
        return atom_loss + bond_loss
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from .augmented import (
|
|
2
|
+
GATAugNodePoolGraphPred,
|
|
3
|
+
GATGraphNodeFGNodePoolGraphPred,
|
|
4
|
+
ResGatedAugNodePoolGraphPred,
|
|
5
|
+
ResGatedGraphNodeFGNodePoolGraphPred,
|
|
6
|
+
)
|
|
7
|
+
from .dynamic_gni import ResGatedDynamicGNIGraphPred
|
|
8
|
+
from .gat import GATGraphPred
|
|
9
|
+
from .resgated import ResGatedGraphPred
|
|
10
|
+
|
|
11
|
+
# Public model classes re-exported by this package.
__all__ = [
    # Residual gated models
    "ResGatedGraphPred",
    "ResGatedAugNodePoolGraphPred",
    "ResGatedGraphNodeFGNodePoolGraphPred",
    # Graph attention models
    "GATGraphPred",
    "GATAugNodePoolGraphPred",
    "GATGraphNodeFGNodePoolGraphPred",
    # Dynamic GNI model
    "ResGatedDynamicGNIGraphPred",
]
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
from .base import AugmentedNodePoolingNet, GraphNodeFGNodePoolingNet
|
|
2
|
+
from .gat import GATGraphPred
|
|
3
|
+
from .resgated import ResGatedGraphPred
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred):
    """Residual gated graph predictor with augmented-node pooling.

    Defines no members of its own; it only combines two bases:

    - AugmentedNodePoolingNet: pools atom and augmented node embeddings
      (optionally with molecule attributes).
    - ResGatedGraphPred: residual gated network for the final graph
      prediction.

    AugmentedNodePoolingNet is listed first and therefore precedes
    ResGatedGraphPred in the method resolution order.
    """
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class GATAugNodePoolGraphPred(AugmentedNodePoolingNet, GATGraphPred):
    """Graph attention graph predictor with augmented-node pooling.

    Defines no members of its own; it only combines two bases:

    - AugmentedNodePoolingNet: pools atom and augmented node embeddings
      (optionally with molecule attributes).
    - GATGraphPred: graph attention network for the final graph prediction.

    AugmentedNodePoolingNet is listed first and therefore precedes
    GATGraphPred in the method resolution order.
    """
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class ResGatedGraphNodeFGNodePoolGraphPred(
    GraphNodeFGNodePoolingNet,
    ResGatedGraphPred,
):
    """Residual gated graph predictor with atom/FG/graph-node pooling.

    Defines no members of its own; it only combines two bases:

    - GraphNodeFGNodePoolingNet: pools atom, functional group, and graph
      nodes (optionally with molecule attributes).
    - ResGatedGraphPred: residual gated network for the final graph
      prediction.

    GraphNodeFGNodePoolingNet is listed first and therefore precedes
    ResGatedGraphPred in the method resolution order.
    """
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class GATGraphNodeFGNodePoolGraphPred(GraphNodeFGNodePoolingNet, GATGraphPred):
    """Graph attention graph predictor with atom/FG/graph-node pooling.

    Defines no members of its own; it only combines two bases:

    - GraphNodeFGNodePoolingNet: pools atom, functional group, and graph
      nodes (optionally with molecule attributes).
    - GATGraphPred: graph attention network for the final graph prediction.

    GraphNodeFGNodePoolingNet is listed first and therefore precedes
    GATGraphPred in the method resolution order.
    """
|