alphagrammar 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. alphagrammar-0.1.0/LICENSE +21 -0
  2. alphagrammar-0.1.0/PKG-INFO +131 -0
  3. alphagrammar-0.1.0/README.md +91 -0
  4. alphagrammar-0.1.0/pyproject.toml +38 -0
  5. alphagrammar-0.1.0/setup.cfg +4 -0
  6. alphagrammar-0.1.0/src/alphagrammar/GCN/__init__.py +1 -0
  7. alphagrammar-0.1.0/src/alphagrammar/GCN/batch.py +228 -0
  8. alphagrammar-0.1.0/src/alphagrammar/GCN/feature_extract.py +38 -0
  9. alphagrammar-0.1.0/src/alphagrammar/GCN/loader.py +1329 -0
  10. alphagrammar-0.1.0/src/alphagrammar/GCN/model.py +474 -0
  11. alphagrammar-0.1.0/src/alphagrammar/GCN/util.py +421 -0
  12. alphagrammar-0.1.0/src/alphagrammar/__init__.py +15 -0
  13. alphagrammar-0.1.0/src/alphagrammar/agent.py +304 -0
  14. alphagrammar-0.1.0/src/alphagrammar/cli.py +37 -0
  15. alphagrammar-0.1.0/src/alphagrammar/commands/__init__.py +1 -0
  16. alphagrammar-0.1.0/src/alphagrammar/commands/parse.py +195 -0
  17. alphagrammar-0.1.0/src/alphagrammar/core.py +449 -0
  18. alphagrammar-0.1.0/src/alphagrammar/data/.gitkeep +0 -0
  19. alphagrammar-0.1.0/src/alphagrammar/data/README.txt +28 -0
  20. alphagrammar-0.1.0/src/alphagrammar/data/best_agent.pkl +0 -0
  21. alphagrammar-0.1.0/src/alphagrammar/data/supervised_contextpred.pth +0 -0
  22. alphagrammar-0.1.0/src/alphagrammar/data/vocab.pkl +0 -0
  23. alphagrammar-0.1.0/src/alphagrammar/fuseprop/__init__.py +5 -0
  24. alphagrammar-0.1.0/src/alphagrammar/fuseprop/chemutils.py +572 -0
  25. alphagrammar-0.1.0/src/alphagrammar/fuseprop/dataset.py +88 -0
  26. alphagrammar-0.1.0/src/alphagrammar/fuseprop/decoder.py +250 -0
  27. alphagrammar-0.1.0/src/alphagrammar/fuseprop/encoder.py +72 -0
  28. alphagrammar-0.1.0/src/alphagrammar/fuseprop/gnn.py +70 -0
  29. alphagrammar-0.1.0/src/alphagrammar/fuseprop/inc_graph.py +156 -0
  30. alphagrammar-0.1.0/src/alphagrammar/fuseprop/mol_graph.py +157 -0
  31. alphagrammar-0.1.0/src/alphagrammar/fuseprop/nnutils.py +64 -0
  32. alphagrammar-0.1.0/src/alphagrammar/fuseprop/rnn.py +124 -0
  33. alphagrammar-0.1.0/src/alphagrammar/fuseprop/vocab.py +70 -0
  34. alphagrammar-0.1.0/src/alphagrammar/grammar_generation.py +614 -0
  35. alphagrammar-0.1.0/src/alphagrammar/hrg_td_parser_undirected.py +1585 -0
  36. alphagrammar-0.1.0/src/alphagrammar/private/__init__.py +8 -0
  37. alphagrammar-0.1.0/src/alphagrammar/private/grammar.py +1200 -0
  38. alphagrammar-0.1.0/src/alphagrammar/private/hypergraph.py +1212 -0
  39. alphagrammar-0.1.0/src/alphagrammar/private/metrics.py +171 -0
  40. alphagrammar-0.1.0/src/alphagrammar/private/molecule_graph.py +335 -0
  41. alphagrammar-0.1.0/src/alphagrammar/private/rule_stats.py +74 -0
  42. alphagrammar-0.1.0/src/alphagrammar/private/subgraph_set.py +51 -0
  43. alphagrammar-0.1.0/src/alphagrammar/private/symbol.py +172 -0
  44. alphagrammar-0.1.0/src/alphagrammar/private/utils.py +101 -0
  45. alphagrammar-0.1.0/src/alphagrammar.egg-info/PKG-INFO +131 -0
  46. alphagrammar-0.1.0/src/alphagrammar.egg-info/SOURCES.txt +48 -0
  47. alphagrammar-0.1.0/src/alphagrammar.egg-info/dependency_links.txt +1 -0
  48. alphagrammar-0.1.0/src/alphagrammar.egg-info/entry_points.txt +2 -0
  49. alphagrammar-0.1.0/src/alphagrammar.egg-info/requires.txt +8 -0
  50. alphagrammar-0.1.0/src/alphagrammar.egg-info/top_level.txt +1 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Minghao Guo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,131 @@
1
+ Metadata-Version: 2.1
2
+ Name: alphagrammar
3
+ Version: 0.1.0
4
+ Summary: AlphaGrammar: Grammar-based molecular representation learning
5
+ Author-email: Michael Sun <msun415@mit.edu>
6
+ License: MIT License
7
+
8
+ Copyright (c) 2022 Minghao Guo
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
27
+
28
+ Keywords: molecule,grammar,SMILES,representation-learning,chemistry
29
+ Requires-Python: >=3.8
30
+ Description-Content-Type: text/markdown
31
+ License-File: LICENSE
32
+ Requires-Dist: numpy
33
+ Requires-Dist: pandas
34
+ Requires-Dist: torch
35
+ Requires-Dist: rdkit
36
+ Requires-Dist: networkx
37
+ Requires-Dist: matplotlib
38
+ Requires-Dist: tqdm
39
+ Requires-Dist: torch_geometric
40
+
41
+ # AlphaGrammar
42
+
43
+ **AlphaGrammar** is a grammar-based molecular representation learning framework.
44
+ It learns a hyperedge-replacement grammar (HRG) over molecular graphs via
45
+ Markov chain Monte Carlo (MCMC) sampling with a learned agent.
46
+
47
+ ## Installation
48
+
49
+ ```bash
50
+ pip install -e /path/to/AlphaGrammar_pkg
51
+ ```
52
+
53
+ Or from PyPI (once published):
54
+
55
+ ```bash
56
+ pip install alphagrammar
57
+ ```
58
+
59
+ ### Pretrained models
60
+
61
+ After installation, copy the pretrained model files to the package data directory:
62
+
63
+ ```bash
64
+ DATA_DIR=$(python -c "import alphagrammar, os; print(os.path.join(os.path.dirname(alphagrammar.__file__), 'data'))")
65
+ cp /path/to/AlphaGrammar/ckpts/vocab_epoch5.pkl $DATA_DIR/
66
+ cp /path/to/AlphaGrammar/ckpts/best_agent_epoch0_R0.0000.pkl $DATA_DIR/
67
+ cp /path/to/AlphaGrammar/GCN/supervised_contextpred.pth $DATA_DIR/
68
+ ```
69
+
70
+ ## Usage
71
+
72
+ ### Command-line interface
73
+
74
+ ```bash
75
+ # Parse a single SMILES string (rollout mode)
76
+ alphagrammar parse "CCO"
77
+
78
+ # Parse from a file (one SMILES per line)
79
+ alphagrammar parse molecules.smi
80
+
81
+ # Parse with Bolinas parser (10-second timeout per molecule)
82
+ alphagrammar parse "CCO" --timeout 10
83
+
84
+ # Use only top-100 rules from vocab
85
+ alphagrammar parse "CCO" --vocab_size 100
86
+ ```
87
+
88
+ ### Python API
89
+
90
+ ```python
91
+ from alphagrammar import MoleculeDataset, _collate_mol_batch, RuleStats
92
+ from alphagrammar.grammar_generation import MCMC_sampling
93
+ from alphagrammar.agent import Agent
94
+ import torch, pickle
95
+
96
+ # Load pretrained models
97
+ with open("data/vocab_epoch5.pkl", "rb") as f:
98
+ rule_stats = pickle.load(f)
99
+
100
+ agent = Agent(feat_dim=300, hidden_size=256)
101
+ agent.load_state_dict(torch.load("data/best_agent_epoch0_R0.0000.pkl"))
102
+ agent.eval()
103
+
104
+ # Build input dataset
105
+ dataset = MoleculeDataset(["CCO", "c1ccccc1"], GNN_model_path="data/supervised_contextpred.pth")
106
+ batch = [dataset[i] for i in range(len(dataset))]
107
+ input_graphs_dict = _collate_mol_batch(batch)
108
+
109
+ # Run MCMC sampling
110
+ results, rules_per_mol, sequential_steps = MCMC_sampling(
111
+ "output_dir", agent, input_graphs_dict, MCMC_size=1, debug=True
112
+ )
113
+ ```
114
+
115
+ ## Package structure
116
+
117
+ ```
118
+ src/alphagrammar/
119
+ ├── __init__.py # Public API
120
+ ├── cli.py # argparse CLI entry point
121
+ ├── core.py # Core functions (bolinas_evaluate, MoleculeDataset, RuleStats, ...)
122
+ ├── grammar_generation.py# MCMC sampling and grammar generation
123
+ ├── agent.py # Neural agent (policy network)
124
+ ├── hrg_td_parser_undirected.py # Bolinas HRG parser
125
+ ├── private/ # Internal hypergraph / grammar data structures
126
+ ├── fuseprop/ # Molecular fragmentation utilities
127
+ ├── GCN/ # Graph neural network feature extraction
128
+ ├── data/ # Pretrained model files (copy here after install)
129
+ └── commands/
130
+ └── parse.py # 'alphagrammar parse' subcommand
131
+ ```
@@ -0,0 +1,91 @@
1
+ # AlphaGrammar
2
+
3
+ **AlphaGrammar** is a grammar-based molecular representation learning framework.
4
+ It learns a hyperedge-replacement grammar (HRG) over molecular graphs via
5
+ Markov chain Monte Carlo (MCMC) sampling with a learned agent.
6
+
7
+ ## Installation
8
+
9
+ ```bash
10
+ pip install -e /path/to/AlphaGrammar_pkg
11
+ ```
12
+
13
+ Or from PyPI (once published):
14
+
15
+ ```bash
16
+ pip install alphagrammar
17
+ ```
18
+
19
+ ### Pretrained models
20
+
21
+ After installation, copy the pretrained model files to the package data directory:
22
+
23
+ ```bash
24
+ DATA_DIR=$(python -c "import alphagrammar, os; print(os.path.join(os.path.dirname(alphagrammar.__file__), 'data'))")
25
+ cp /path/to/AlphaGrammar/ckpts/vocab_epoch5.pkl $DATA_DIR/
26
+ cp /path/to/AlphaGrammar/ckpts/best_agent_epoch0_R0.0000.pkl $DATA_DIR/
27
+ cp /path/to/AlphaGrammar/GCN/supervised_contextpred.pth $DATA_DIR/
28
+ ```
29
+
30
+ ## Usage
31
+
32
+ ### Command-line interface
33
+
34
+ ```bash
35
+ # Parse a single SMILES string (rollout mode)
36
+ alphagrammar parse "CCO"
37
+
38
+ # Parse from a file (one SMILES per line)
39
+ alphagrammar parse molecules.smi
40
+
41
+ # Parse with Bolinas parser (10-second timeout per molecule)
42
+ alphagrammar parse "CCO" --timeout 10
43
+
44
+ # Use only top-100 rules from vocab
45
+ alphagrammar parse "CCO" --vocab_size 100
46
+ ```
47
+
48
+ ### Python API
49
+
50
+ ```python
51
+ from alphagrammar import MoleculeDataset, _collate_mol_batch, RuleStats
52
+ from alphagrammar.grammar_generation import MCMC_sampling
53
+ from alphagrammar.agent import Agent
54
+ import torch, pickle
55
+
56
+ # Load pretrained models
57
+ with open("data/vocab_epoch5.pkl", "rb") as f:
58
+ rule_stats = pickle.load(f)
59
+
60
+ agent = Agent(feat_dim=300, hidden_size=256)
61
+ agent.load_state_dict(torch.load("data/best_agent_epoch0_R0.0000.pkl"))
62
+ agent.eval()
63
+
64
+ # Build input dataset
65
+ dataset = MoleculeDataset(["CCO", "c1ccccc1"], GNN_model_path="data/supervised_contextpred.pth")
66
+ batch = [dataset[i] for i in range(len(dataset))]
67
+ input_graphs_dict = _collate_mol_batch(batch)
68
+
69
+ # Run MCMC sampling
70
+ results, rules_per_mol, sequential_steps = MCMC_sampling(
71
+ "output_dir", agent, input_graphs_dict, MCMC_size=1, debug=True
72
+ )
73
+ ```
74
+
75
+ ## Package structure
76
+
77
+ ```
78
+ src/alphagrammar/
79
+ ├── __init__.py # Public API
80
+ ├── cli.py # argparse CLI entry point
81
+ ├── core.py # Core functions (bolinas_evaluate, MoleculeDataset, RuleStats, ...)
82
+ ├── grammar_generation.py# MCMC sampling and grammar generation
83
+ ├── agent.py # Neural agent (policy network)
84
+ ├── hrg_td_parser_undirected.py # Bolinas HRG parser
85
+ ├── private/ # Internal hypergraph / grammar data structures
86
+ ├── fuseprop/ # Molecular fragmentation utilities
87
+ ├── GCN/ # Graph neural network feature extraction
88
+ ├── data/ # Pretrained model files (copy here after install)
89
+ └── commands/
90
+ └── parse.py # 'alphagrammar parse' subcommand
91
+ ```
@@ -0,0 +1,38 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "alphagrammar"
7
+ version = "0.1.0"
8
+ description = "AlphaGrammar: Grammar-based molecular representation learning"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ license = {file = "LICENSE"}
12
+ authors = [
13
+ { name = "Michael Sun", email = "msun415@mit.edu" }
14
+ ]
15
+ keywords = ["molecule", "grammar", "SMILES", "representation-learning", "chemistry"]
16
+ dependencies = [
17
+ "numpy",
18
+ "pandas",
19
+ "torch",
20
+ "rdkit",
21
+ "networkx",
22
+ "matplotlib",
23
+ "tqdm",
24
+ "torch_geometric",
25
+ ]
26
+
27
+ [project.scripts]
28
+ alphagrammar = "alphagrammar.cli:main"
29
+
30
+ [tool.setuptools]
31
+ package-dir = {"" = "src"}
32
+
33
+ [tool.setuptools.packages.find]
34
+ where = ["src"]
35
+ include = ["alphagrammar*"]
36
+
37
+ [tool.setuptools.package-data]
38
+ alphagrammar = ["data/*", "data/.gitkeep"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1 @@
1
+ # GCN subpackage for AlphaGrammar
@@ -0,0 +1,228 @@
1
+ import torch
2
+ from torch_geometric.data import Data, Batch
3
+
4
class BatchMasking(Data):
    r"""A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """

    def __init__(self, batch=None, **kwargs):
        # `batch` is the per-node graph-assignment vector (or None until
        # from_data_list builds it).
        super(BatchMasking, self).__init__(**kwargs)
        self.batch = batch

    @staticmethod
    def from_data_list(data_list):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly."""
        # Collect the union of attribute keys over all graphs; each key's
        # per-graph values are concatenated along its cat dimension below.
        keys = [set(data.keys) for data in data_list]
        keys = list(set.union(*keys))
        assert 'batch' not in keys

        batch = BatchMasking()

        for key in keys:
            batch[key] = []
        batch.batch = []

        # Running offsets so per-graph node/edge indices remain valid once all
        # graphs are merged into one disconnected graph.
        cumsum_node = 0
        cumsum_edge = 0

        for i, data in enumerate(data_list):
            num_nodes = data.num_nodes
            batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long))
            for key in data.keys:
                item = data[key]
                # Node-index attributes are shifted by the node offset;
                # indices into the flattened edge list by the edge offset.
                if key in ['edge_index', 'masked_atom_indices']:
                    item = item + cumsum_node
                elif key == 'connected_edge_indices':
                    item = item + cumsum_edge
                batch[key].append(item)

            cumsum_node += num_nodes
            cumsum_edge += data.edge_index.shape[1]

        for key in keys:
            # NOTE(review): relies on the (older) torch_geometric
            # `Data.cat_dim(key, item)` API — confirm against the pinned
            # torch_geometric version.
            batch[key] = torch.cat(
                batch[key], dim=data_list[0].cat_dim(key, batch[key][0]))
        batch.batch = torch.cat(batch.batch, dim=-1)
        return batch.contiguous()

    def cumsum(self, key, item):
        r"""If :obj:`True`, the attribute :obj:`key` with content :obj:`item`
        should be added up cumulatively before concatenated together.
        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        return key in ['edge_index', 'face', 'masked_atom_indices', 'connected_edge_indices']

    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch."""
        # `batch` is built in graph order, so its last entry is the max id.
        return self.batch[-1].item() + 1
68
+
69
class BatchAE(Data):
    r"""A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """

    def __init__(self, batch=None, **kwargs):
        # `batch` is the per-node graph-assignment vector (or None until
        # from_data_list builds it).
        super(BatchAE, self).__init__(**kwargs)
        self.batch = batch

    @staticmethod
    def from_data_list(data_list):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly."""
        # Union of attribute keys over all graphs in the list.
        keys = [set(data.keys) for data in data_list]
        keys = list(set.union(*keys))
        assert 'batch' not in keys

        batch = BatchAE()

        for key in keys:
            batch[key] = []
        batch.batch = []

        # Running node offset so edge indices stay valid after merging.
        cumsum_node = 0

        for i, data in enumerate(data_list):
            num_nodes = data.num_nodes
            batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long))
            for key in data.keys:
                item = data[key]
                # Both positive and negative edge indices reference nodes, so
                # both are shifted by the node offset.
                if key in ['edge_index', 'negative_edge_index']:
                    item = item + cumsum_node
                batch[key].append(item)

            cumsum_node += num_nodes

        for key in keys:
            batch[key] = torch.cat(
                batch[key], dim=batch.cat_dim(key))
        batch.batch = torch.cat(batch.batch, dim=-1)
        return batch.contiguous()

    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch."""
        # `batch` is built in graph order, so its last entry is the max id.
        return self.batch[-1].item() + 1

    def cat_dim(self, key):
        # Edge-index tensors are (2, E) and concatenate along the last dim;
        # everything else concatenates along dim 0.
        return -1 if key in ["edge_index", "negative_edge_index"] else 0
122
+
123
+
124
class BatchSubstructContext(Data):
    r"""A plain old python object modeling a batch of graphs as one big
    (disconnected) graph. With :class:`torch_geometric.data.Data` being the
    base class, all its methods can also be used here.
    In addition, single graphs can be reconstructed via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.
    """

    """
    Specialized batching for substructure context pair!
    """

    def __init__(self, batch=None, **kwargs):
        # `batch` is unused by from_data_list below (it builds
        # `batch_overlapped_context` instead) but kept for Data compatibility.
        super(BatchSubstructContext, self).__init__(**kwargs)
        self.batch = batch

    @staticmethod
    def from_data_list(data_list):
        r"""Constructs a batch object from a python list holding
        :class:`torch_geometric.data.Data` objects.
        The assignment vector :obj:`batch` is created on the fly."""
        #keys = [set(data.keys) for data in data_list]
        #keys = list(set.union(*keys))
        #assert 'batch' not in keys

        batch = BatchSubstructContext()
        # Fixed key set: the substructure graph and the context graph are
        # batched independently, each with its own node-offset accumulator.
        keys = ["center_substruct_idx", "edge_attr_substruct", "edge_index_substruct", "x_substruct", "overlap_context_substruct_idx", "edge_attr_context", "edge_index_context", "x_context"]

        for key in keys:
            #print(key)
            batch[key] = []

        #batch.batch = []
        #used for pooling the context
        batch.batch_overlapped_context = []
        batch.overlapped_context_size = []

        # Separate running node offsets for the main, substructure and
        # context graphs.
        cumsum_main = 0
        cumsum_substruct = 0
        cumsum_context = 0

        # Graph id counter; only advanced for entries that have a context, so
        # assignment ids stay dense.
        i = 0

        for data in data_list:
            #If there is no context, just skip!!
            if hasattr(data, "x_context"):
                num_nodes = data.num_nodes
                num_nodes_substruct = len(data.x_substruct)
                num_nodes_context = len(data.x_context)

                #batch.batch.append(torch.full((num_nodes, ), i, dtype=torch.long))
                batch.batch_overlapped_context.append(torch.full((len(data.overlap_context_substruct_idx), ), i, dtype=torch.long))
                batch.overlapped_context_size.append(len(data.overlap_context_substruct_idx))

                ###batching for the main graph
                #for key in data.keys:
                #    if not "context" in key and not "substruct" in key:
                #        item = data[key]
                #        item = item + cumsum_main if batch.cumsum(key, item) else item
                #        batch[key].append(item)

                ###batching for the substructure graph
                for key in ["center_substruct_idx", "edge_attr_substruct", "edge_index_substruct", "x_substruct"]:
                    item = data[key]
                    # Only index-valued attributes (see cumsum) are shifted.
                    item = item + cumsum_substruct if batch.cumsum(key, item) else item
                    batch[key].append(item)


                ###batching for the context graph
                for key in ["overlap_context_substruct_idx", "edge_attr_context", "edge_index_context", "x_context"]:
                    item = data[key]
                    item = item + cumsum_context if batch.cumsum(key, item) else item
                    batch[key].append(item)

                cumsum_main += num_nodes
                cumsum_substruct += num_nodes_substruct
                cumsum_context += num_nodes_context
                i += 1

        for key in keys:
            batch[key] = torch.cat(
                batch[key], dim=batch.cat_dim(key))
        #batch.batch = torch.cat(batch.batch, dim=-1)
        batch.batch_overlapped_context = torch.cat(batch.batch_overlapped_context, dim=-1)
        batch.overlapped_context_size = torch.LongTensor(batch.overlapped_context_size)

        return batch.contiguous()

    def cat_dim(self, key):
        # Edge-index tensors are (2, E) and concatenate along the last dim;
        # everything else concatenates along dim 0.
        return -1 if key in ["edge_index", "edge_index_substruct", "edge_index_context"] else 0

    def cumsum(self, key, item):
        r"""If :obj:`True`, the attribute :obj:`key` with content :obj:`item`
        should be added up cumulatively before concatenated together.
        .. note::
            This method is for internal use only, and should only be overridden
            if the batch concatenation process is corrupted for a specific data
            attribute.
        """
        return key in ["edge_index", "edge_index_substruct", "edge_index_context", "overlap_context_substruct_idx", "center_substruct_idx"]

    @property
    def num_graphs(self):
        """Returns the number of graphs in the batch."""
        # NOTE(review): from_data_list never populates `self.batch` (it builds
        # `batch_overlapped_context` instead), so this property assumes `batch`
        # was set by other means — confirm before relying on it.
        return self.batch[-1].item() + 1
@@ -0,0 +1,38 @@
1
+ import argparse
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.optim as optim
6
+ from rdkit import Chem
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+ # from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool
10
+ from .model import GNN, GNN_feature
11
+ from .loader import mol_to_graph_data_obj_simple
12
+
13
+
14
class feature_extractor():
    """Extracts per-node features from a molecule with a pretrained GIN model.

    Parameters
    ----------
    pretrained_model_path : str
        Path to the pretrained GNN checkpoint passed to
        ``GNN_feature.from_pretrained``.
    """

    def __init__(self, pretrained_model_path):
        self.pretrained_model_path = pretrained_model_path
        # Cached pretrained model; built lazily on first extract() call.
        self._model = None

    def _get_model(self):
        """Build and load the pretrained GNN once, then reuse it.

        Fixes the original behavior of reconstructing the model and reloading
        the checkpoint from disk on every extract() call.
        """
        if self._model is None:
            model = GNN_feature(num_layer=5, emb_dim=300, num_tasks=1, JK='last',
                                drop_ratio=0, graph_pooling='mean', gnn_type='gin')
            model.from_pretrained(self.pretrained_model_path)
            # model.cuda(device=0)
            model.eval()
            self._model = model
        return self._model

    def preprocessing(self, graph_mol):
        """Convert an RDKit Mol into a torch_geometric Data object.

        The molecule is round-tripped through SMILES
        (MolToSmiles -> MolFromSmiles) before conversion; per the original
        note, this kekulizes the molecule to distinguish the aromatic bonds.
        """
        data = mol_to_graph_data_obj_simple(Chem.MolFromSmiles(Chem.MolToSmiles(graph_mol)))
        return data

    def extract(self, graph_mol):
        """Return node-level feature tensor for *graph_mol*.

        Runs the cached pretrained GNN in eval mode under ``torch.no_grad()``,
        so no autograd state is recorded.
        """
        model = self._get_model()
        graph_data = self.preprocessing(graph_mol)
        # graph_data = graph_data.cuda(device=0)
        with torch.no_grad():
            node_features = model(graph_data.x, graph_data.edge_index, graph_data.edge_attr)
        return node_features