chebai-graph 1.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. chebai_graph-1.0.0/PKG-INFO +14 -0
  2. chebai_graph-1.0.0/chebai_graph/__init__.py +0 -0
  3. chebai_graph-1.0.0/chebai_graph/loss/__init__.py +0 -0
  4. chebai_graph-1.0.0/chebai_graph/loss/pretraining.py +31 -0
  5. chebai_graph-1.0.0/chebai_graph/models/__init__.py +19 -0
  6. chebai_graph-1.0.0/chebai_graph/models/augmented.py +45 -0
  7. chebai_graph-1.0.0/chebai_graph/models/base.py +707 -0
  8. chebai_graph-1.0.0/chebai_graph/models/dynamic_gni.py +214 -0
  9. chebai_graph-1.0.0/chebai_graph/models/gat.py +94 -0
  10. chebai_graph-1.0.0/chebai_graph/models/gin_net.py +96 -0
  11. chebai_graph-1.0.0/chebai_graph/models/graph.py +373 -0
  12. chebai_graph-1.0.0/chebai_graph/models/resgated.py +109 -0
  13. chebai_graph-1.0.0/chebai_graph/preprocessing/__init__.py +0 -0
  14. chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomCharge/indices_one_hot.txt +13 -0
  15. chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomFunctionalGroup/indices_one_hot.txt +158 -0
  16. chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomHybridization/indices_one_hot.txt +7 -0
  17. chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomNodeLevel/indices_one_hot.txt +3 -0
  18. chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomNumHs/indices_one_hot.txt +7 -0
  19. chebai_graph-1.0.0/chebai_graph/preprocessing/bin/AtomType/indices_one_hot.txt +119 -0
  20. chebai_graph-1.0.0/chebai_graph/preprocessing/bin/BondLevel/indices_one_hot.txt +4 -0
  21. chebai_graph-1.0.0/chebai_graph/preprocessing/bin/BondType/indices_one_hot.txt +5 -0
  22. chebai_graph-1.0.0/chebai_graph/preprocessing/bin/NumAtomBonds/indices_one_hot.txt +11 -0
  23. chebai_graph-1.0.0/chebai_graph/preprocessing/collate.py +78 -0
  24. chebai_graph-1.0.0/chebai_graph/preprocessing/datasets/__init__.py +36 -0
  25. chebai_graph-1.0.0/chebai_graph/preprocessing/datasets/chebi.py +734 -0
  26. chebai_graph-1.0.0/chebai_graph/preprocessing/datasets/pubchem.py +7 -0
  27. chebai_graph-1.0.0/chebai_graph/preprocessing/datasets/utils.py +43 -0
  28. chebai_graph-1.0.0/chebai_graph/preprocessing/fg_detection/__init__.py +0 -0
  29. chebai_graph-1.0.0/chebai_graph/preprocessing/fg_detection/fg_aware_rule_based.py +1936 -0
  30. chebai_graph-1.0.0/chebai_graph/preprocessing/fg_detection/fg_constants.py +14 -0
  31. chebai_graph-1.0.0/chebai_graph/preprocessing/properties/__init__.py +89 -0
  32. chebai_graph-1.0.0/chebai_graph/preprocessing/properties/augmented_properties.py +411 -0
  33. chebai_graph-1.0.0/chebai_graph/preprocessing/properties/base.py +455 -0
  34. chebai_graph-1.0.0/chebai_graph/preprocessing/properties/constants.py +13 -0
  35. chebai_graph-1.0.0/chebai_graph/preprocessing/properties/properties.py +299 -0
  36. chebai_graph-1.0.0/chebai_graph/preprocessing/property_encoder.py +292 -0
  37. chebai_graph-1.0.0/chebai_graph/preprocessing/reader/__init__.py +28 -0
  38. chebai_graph-1.0.0/chebai_graph/preprocessing/reader/augmented_reader.py +943 -0
  39. chebai_graph-1.0.0/chebai_graph/preprocessing/reader/reader.py +203 -0
  40. chebai_graph-1.0.0/chebai_graph/preprocessing/reader/static_gni.py +191 -0
  41. chebai_graph-1.0.0/chebai_graph/preprocessing/structures.py +42 -0
  42. chebai_graph-1.0.0/chebai_graph/preprocessing/transform_unlabeled.py +109 -0
  43. chebai_graph-1.0.0/pyproject.toml +34 -0
@@ -0,0 +1,14 @@
1
+ Metadata-Version: 2.4
2
+ Name: chebai-graph
3
+ Version: 1.0.0
4
+ Summary: GNNs for ChEB-AI
5
+ Author-email: Martin Glauer <martin.glauer@ovgu.de>
6
+ Requires-Python: >=3.8
7
+ Requires-Dist: chebai
8
+ Requires-Dist: descriptastorus
9
+ Requires-Dist: tox ; extra == "dev"
10
+ Requires-Dist: isort ; extra == "linters"
11
+ Requires-Dist: pre-commit ; extra == "linters"
12
+ Requires-Dist: black ; extra == "linters"
13
+ Provides-Extra: dev
14
+ Provides-Extra: linters
File without changes
File without changes
@@ -0,0 +1,31 @@
1
import logging

import torch
2
+
3
+
4
logger = logging.getLogger(__name__)


class MaskPretrainingLoss(torch.nn.Module):
    """Masked-attribute pretraining loss.

    Atoms (and optionally bonds) are masked and the model is trained to
    predict them (see Hu et al., 2020: "Strategies for Pre-training Graph
    Neural Networks"). Each component is a binary cross-entropy on logits,
    and the components are summed.
    """

    def __init__(self):
        super().__init__()
        # Despite the attribute name, this is *binary* cross-entropy on logits.
        self.ce = torch.nn.functional.binary_cross_entropy_with_logits

    def _component_loss(self, preds, targets, name):
        """Return the BCE-with-logits loss for one component, or 0 on failure.

        Failures are logged rather than raised, so one malformed component
        does not abort the step. ``binary_cross_entropy_with_logits`` raises
        ``ValueError`` when the prediction and target sizes differ and
        ``RuntimeError`` for other tensor errors (e.g. an incompatible dtype).
        Both are handled here; previously only ``RuntimeError`` was caught,
        so shape mismatches escaped the handler.

        Args:
            preds: Logits for this component.
            targets: Targets for this component.
            name: Label used in the log message ("atom" or "bond").

        Returns:
            The loss tensor, or the int ``0`` if it could not be computed.
        """
        try:
            return self.ce(preds, targets)
        except (RuntimeError, ValueError) as e:
            logger.warning(
                "Failed to compute %s loss: %s (preds: %s, labels: %s)",
                name,
                e,
                tuple(preds.shape),
                tuple(targets.shape),
            )
            return 0

    def forward(self, input, target, **loss_kwargs):
        """Compute the pretraining loss.

        Args:
            input: Atom logits, or a tuple ``(atom_logits, bond_logits)``.
            target: Atom targets, or a tuple ``(atom_targets, bond_targets)``
                when ``input`` is a tuple.
            **loss_kwargs: Ignored.

        Returns:
            The sum of the atom loss and (if given) the bond loss. A component
            that cannot be computed contributes 0.
            NOTE(review): if no component can be computed, the result is the
            plain int ``0`` rather than a tensor, so ``.backward()`` on it will
            fail; confirm callers handle that.
        """
        if isinstance(input, tuple):
            atom_preds, bond_preds = input
            atom_targets, bond_targets = target
            bond_loss = self._component_loss(bond_preds, bond_targets, "bond")
        else:
            atom_preds, atom_targets = input, target
            bond_loss = 0
        atom_loss = self._component_loss(atom_preds, atom_targets, "atom")
        return atom_loss + bond_loss
@@ -0,0 +1,19 @@
1
+ from .augmented import (
2
+ GATAugNodePoolGraphPred,
3
+ GATGraphNodeFGNodePoolGraphPred,
4
+ ResGatedAugNodePoolGraphPred,
5
+ ResGatedGraphNodeFGNodePoolGraphPred,
6
+ )
7
+ from .dynamic_gni import ResGatedDynamicGNIGraphPred
8
+ from .gat import GATGraphPred
9
+ from .resgated import ResGatedGraphPred
10
+
11
+ __all__ = [
12
+ "ResGatedGraphPred",
13
+ "ResGatedAugNodePoolGraphPred",
14
+ "ResGatedGraphNodeFGNodePoolGraphPred",
15
+ "GATGraphPred",
16
+ "GATAugNodePoolGraphPred",
17
+ "GATGraphNodeFGNodePoolGraphPred",
18
+ "ResGatedDynamicGNIGraphPred",
19
+ ]
@@ -0,0 +1,45 @@
1
+ from .base import AugmentedNodePoolingNet, GraphNodeFGNodePoolingNet
2
+ from .gat import GATGraphPred
3
+ from .resgated import ResGatedGraphPred
4
+
5
+
6
class ResGatedAugNodePoolGraphPred(AugmentedNodePoolingNet, ResGatedGraphPred):
    """Residual gated graph predictor with augmented-node pooling.

    Pure mixin composition that adds no members of its own:

    * ``AugmentedNodePoolingNet`` pools atom and augmented-node embeddings,
      optionally together with molecule attributes.
    * ``ResGatedGraphPred`` provides the residual gated network used for the
      final graph-level prediction.

    ``AugmentedNodePoolingNet`` is listed first, so it precedes
    ``ResGatedGraphPred`` in the MRO. Methods defined on both are resolved
    from ``AugmentedNodePoolingNet``.
    """
14
+
15
+
16
class GATAugNodePoolGraphPred(AugmentedNodePoolingNet, GATGraphPred):
    """Graph attention predictor with augmented-node pooling.

    Pure mixin composition that adds no members of its own:

    * ``AugmentedNodePoolingNet`` pools atom and augmented-node embeddings,
      optionally together with molecule attributes.
    * ``GATGraphPred`` provides the graph attention network used for the
      final graph-level prediction.

    ``AugmentedNodePoolingNet`` is listed first, so it precedes
    ``GATGraphPred`` in the MRO. Methods defined on both are resolved from
    ``AugmentedNodePoolingNet``.
    """
24
+
25
+
26
class ResGatedGraphNodeFGNodePoolGraphPred(GraphNodeFGNodePoolingNet, ResGatedGraphPred):
    """Residual gated graph predictor pooling atom, FG, and graph nodes.

    Pure mixin composition that adds no members of its own:

    * ``GraphNodeFGNodePoolingNet`` pools atom, functional-group, and graph
      nodes, optionally together with molecule attributes.
    * ``ResGatedGraphPred`` provides the residual gated network used for the
      final graph-level prediction.

    ``GraphNodeFGNodePoolingNet`` is listed first, so it precedes
    ``ResGatedGraphPred`` in the MRO. Methods defined on both are resolved
    from ``GraphNodeFGNodePoolingNet``.
    """
36
+
37
+
38
class GATGraphNodeFGNodePoolGraphPred(GraphNodeFGNodePoolingNet, GATGraphPred):
    """Graph attention predictor pooling atom, FG, and graph nodes.

    Pure mixin composition that adds no members of its own:

    * ``GraphNodeFGNodePoolingNet`` pools atom, functional-group, and graph
      nodes, optionally together with molecule attributes.
    * ``GATGraphPred`` provides the graph attention network used for the
      final graph-level prediction.

    ``GraphNodeFGNodePoolingNet`` is listed first, so it precedes
    ``GATGraphPred`` in the MRO. Methods defined on both are resolved from
    ``GraphNodeFGNodePoolingNet``.
    """