mofclassifier 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mofclassifier-0.1.2/MOFClassifier/CLscore.py +459 -0
- mofclassifier-0.1.2/MOFClassifier/__init__.py +1 -0
- mofclassifier-0.1.2/PKG-INFO +67 -0
- mofclassifier-0.1.2/README.md +45 -0
- mofclassifier-0.1.2/mofclassifier.egg-info/PKG-INFO +67 -0
- mofclassifier-0.1.2/mofclassifier.egg-info/SOURCES.txt +9 -0
- mofclassifier-0.1.2/mofclassifier.egg-info/dependency_links.txt +1 -0
- mofclassifier-0.1.2/mofclassifier.egg-info/requires.txt +7 -0
- mofclassifier-0.1.2/mofclassifier.egg-info/top_level.txt +1 -0
- mofclassifier-0.1.2/setup.cfg +4 -0
- mofclassifier-0.1.2/setup.py +32 -0
|
@@ -0,0 +1,459 @@
|
|
|
1
|
+
from __future__ import print_function, division
|
|
2
|
+
|
|
3
|
+
import io, os, json, warnings, argparse, requests, zipfile
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from tqdm import tqdm
|
|
8
|
+
|
|
9
|
+
from pymatgen.core.structure import Structure
|
|
10
|
+
from ase.io import read, write
|
|
11
|
+
from sklearn import metrics
|
|
12
|
+
|
|
13
|
+
import torch
|
|
14
|
+
import torch.nn as nn
|
|
15
|
+
from torch.autograd import Variable
|
|
16
|
+
from torch.utils.data import DataLoader
|
|
17
|
+
|
|
18
|
+
warnings.filterwarnings(
|
|
19
|
+
"ignore",
|
|
20
|
+
category=UserWarning,
|
|
21
|
+
module="pymatgen.io.cif"
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
# --- First-import bootstrap: fetch model weights and atom embeddings ------
# On first import, the pretrained checkpoint directories and the per-element
# feature table are downloaded from the project's GitHub repository so the
# installed package is self-contained afterwards.

package_directory = os.path.abspath(os.path.dirname(__file__))

_REPO_ZIP_URL = "https://github.com/Chung-Research-Group/MOFClassifier/archive/refs/heads/main.zip"


def _ensure_model_dir(subdir):
    """Populate ``<package>/<subdir>`` from the repository zip if needed.

    Downloads the repo archive and extracts only the entries under
    ``MOFClassifier-main/MOFClassifier/<subdir>/``. A directory that already
    exists and is non-empty is left untouched.

    Parameters
    ----------
    subdir : str
        Model directory name: "models", "models_qsp" or "models_h".

    Returns
    -------
    str
        Absolute path of the (now populated) local model directory.
    """
    target = os.path.join(package_directory, subdir)
    if os.path.isdir(target) and os.listdir(target):
        return target
    os.makedirs(target, exist_ok=True)
    resp = requests.get(_REPO_ZIP_URL)
    resp.raise_for_status()
    prefix = "MOFClassifier-main/MOFClassifier/{}/".format(subdir)
    with zipfile.ZipFile(io.BytesIO(resp.content)) as z:
        for member in z.namelist():
            # Skip entries outside the wanted directory and bare folder
            # entries (names ending in "/").
            if not member.startswith(prefix) or member.endswith("/"):
                continue
            rel_path = member[len(prefix):]
            dest_path = os.path.join(target, rel_path)
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            with z.open(member) as src, open(dest_path, "wb") as dst:
                dst.write(src.read())
            print(f"Extracted {rel_path} to {subdir}/")
    return target


models_dir = _ensure_model_dir("models")
models_dir_qsp = _ensure_model_dir("models_qsp")
models_dir_h = _ensure_model_dir("models_h")

# Per-element feature table used by AtomCustomJSONInitializer.
atom_url = "https://raw.githubusercontent.com/Chung-Research-Group/MOFClassifier/main/MOFClassifier/atom_init.json"
atom_path = os.path.join(package_directory, "atom_init.json")

if not os.path.exists(atom_path):
    resp = requests.get(atom_url)
    resp.raise_for_status()
    with open(atom_path, "wb") as f:
        f.write(resp.content)
    print("Downloaded atom_init.json")
|
|
99
|
+
|
|
100
|
+
def collate_pool(dataset_list):
    """Collate a list of crystal samples into one batched graph.

    Each sample is ``((atom_fea, nbr_fea, nbr_fea_idx), cif_id)``.
    Neighbor indices are shifted so every crystal addresses its own atoms
    inside the concatenated batch tensors.

    Returns
    -------
    tuple
        ``((atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx), cif_ids)``
        where ``crystal_atom_idx`` maps each crystal to its atom rows.
    """
    atom_feas, nbr_feas, nbr_idxs = [], [], []
    crystal_atom_idx, cif_ids = [], []
    offset = 0
    for (atom_fea, nbr_fea, nbr_fea_idx), cif_id in dataset_list:
        n_atoms = atom_fea.shape[0]
        atom_feas.append(atom_fea)
        nbr_feas.append(nbr_fea)
        # Shift neighbor indices into the batched coordinate system.
        nbr_idxs.append(nbr_fea_idx + offset)
        crystal_atom_idx.append(torch.LongTensor(np.arange(n_atoms) + offset))
        cif_ids.append(cif_id)
        offset += n_atoms
    batched = (torch.cat(atom_feas, dim=0),
               torch.cat(nbr_feas, dim=0),
               torch.cat(nbr_idxs, dim=0),
               crystal_atom_idx)
    return batched, cif_ids
|
|
120
|
+
|
|
121
|
+
class GaussianDistance(object):
    """Expand scalar distances onto a Gaussian basis.

    Centers are evenly spaced on ``[dmin, dmax]`` with spacing ``step``;
    ``var`` (defaulting to ``step``) controls the Gaussian width.
    """

    def __init__(self, dmin, dmax, step, var=None):
        assert dmin < dmax
        assert dmax - dmin > step
        self.filter = np.arange(dmin, dmax + step, step)
        self.var = step if var is None else var

    def expand(self, distances):
        """Return ``distances`` expanded to shape ``(..., n_centers)``."""
        delta = distances[..., np.newaxis] - self.filter
        return np.exp(-delta ** 2 / self.var ** 2)
|
|
132
|
+
|
|
133
|
+
class AtomInitializer(object):
    """Base class mapping atom types to fixed feature vectors."""

    def __init__(self, atom_types):
        self.atom_types = set(atom_types)
        self._embedding = {}

    def get_atom_fea(self, atom_type):
        """Return the feature vector registered for ``atom_type``."""
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        """Replace the embedding table and rebuild derived lookups."""
        self._embedding = state_dict
        self.atom_types = set(self._embedding.keys())
        self._decodedict = {value: key
                            for key, value in self._embedding.items()}

    def state_dict(self):
        return self._embedding

    def decode(self, idx):
        """Map an embedding value back to its atom type (lazy lookup)."""
        if not hasattr(self, '_decodedict'):
            self._decodedict = {value: key
                                for key, value in self._embedding.items()}
        return self._decodedict[idx]
|
|
152
|
+
|
|
153
|
+
class AtomCustomJSONInitializer(AtomInitializer):
    """Atom initializer whose embeddings are read from a JSON file.

    The file maps element atomic numbers (JSON string keys) to lists of
    feature values; each list becomes a float numpy array.
    """

    def __init__(self, elem_embedding_file):
        with open(elem_embedding_file) as f:
            raw = json.load(f)
        # JSON object keys are strings; atom types are integers.
        elem_embedding = {int(num): fea for num, fea in raw.items()}
        super(AtomCustomJSONInitializer, self).__init__(set(elem_embedding.keys()))
        for num, fea in elem_embedding.items():
            self._embedding[num] = np.array(fea, dtype=float)
|
|
163
|
+
|
|
164
|
+
def preprocess(root_cif, atom_init_file):
    """Featurize one CIF file into the CGCNN graph representation.

    Parameters
    ----------
    root_cif : str
        Path to the ``.cif`` structure file.
    atom_init_file : str
        Path to the JSON table of per-element feature vectors.

    Returns
    -------
    tuple
        ``((atom_fea, nbr_fea, nbr_fea_idx), cif_id)`` with torch tensors:
        atom features (N, fea), Gaussian-expanded neighbor distances
        (N, 12, n_centers), and neighbor indices (N, 12).
    """
    cif_id = os.path.basename(root_cif).replace(".cif", "")
    ari = AtomCustomJSONInitializer(atom_init_file)
    # Gaussian basis over interatomic distances up to the 8 A cutoff.
    gdf = GaussianDistance(dmin=0, dmax=8, step=0.2)
    try:
        crystal = Structure.from_file(root_cif)
    except Exception:
        # pymatgen failed to parse the CIF: round-trip it through ASE,
        # which is more forgiving, then retry.  NOTE: this rewrites the
        # input file in place.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            atoms = read(root_cif)
            write(root_cif, atoms)
            crystal = Structure.from_file(root_cif)
    atom_fea = np.vstack([ari.get_atom_fea(crystal[j].specie.number)
                          for j in range(len(crystal))])
    all_nbrs = crystal.get_all_neighbors(8, include_index=True)
    # Sort each atom's neighbors by distance (x[1] is the distance).
    all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]
    nbr_fea_idx, nbr_fea = [], []
    for nbr in all_nbrs:
        if len(nbr) < 12:
            warnings.warn('{} not find enough neighbors to build graph. '
                          'If it happens frequently, consider increase '
                          'radius.'.format(cif_id))
            # Pad with index 0 and a distance beyond the cutoff so the
            # padded entries expand to ~zero Gaussian weight.
            nbr_fea_idx.append(list(map(lambda x: x[2], nbr)) + [0] * (12 - len(nbr)))
            nbr_fea.append(list(map(lambda x: x[1], nbr)) + [8 + 1.] * (12 - len(nbr)))
        else:
            # Keep only the 12 nearest neighbors.
            nbr_fea_idx.append(list(map(lambda x: x[2], nbr[:12])))
            nbr_fea.append(list(map(lambda x: x[1], nbr[:12])))
    nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
    nbr_fea = gdf.expand(nbr_fea)
    # Single tensor conversion (the original converted atom_fea twice).
    atom_fea = torch.Tensor(atom_fea)
    nbr_fea = torch.Tensor(nbr_fea)
    nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
    preload_data = ((atom_fea, nbr_fea, nbr_fea_idx), cif_id)
    return preload_data
|
|
199
|
+
|
|
200
|
+
class ConvLayer(nn.Module):
    """One CGCNN graph-convolution layer with gated neighbor aggregation."""

    def __init__(self, atom_fea_len, nbr_fea_len):
        super(ConvLayer, self).__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len
        # Maps [atom_i | atom_j | bond_ij] to gate and core channels.
        self.fc_full = nn.Linear(2 * self.atom_fea_len + self.nbr_fea_len,
                                 2 * self.atom_fea_len)
        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(2 * self.atom_fea_len)
        self.bn2 = nn.BatchNorm1d(self.atom_fea_len)
        self.softplus2 = nn.Softplus()

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        """Update atom features from their neighbors.

        ``atom_in_fea`` is (N, atom_fea_len), ``nbr_fea`` is
        (N, M, nbr_fea_len) and ``nbr_fea_idx`` is (N, M) integer
        neighbor indices; returns updated (N, atom_fea_len) features.
        """
        n_atoms, n_nbrs = nbr_fea_idx.shape
        # Gather neighbor features and concatenate them with the central
        # atom's features and the bond (edge) features.
        neighbor_fea = atom_in_fea[nbr_fea_idx, :]
        center_fea = atom_in_fea.unsqueeze(1).expand(n_atoms, n_nbrs,
                                                     self.atom_fea_len)
        stacked = torch.cat([center_fea, neighbor_fea, nbr_fea], dim=2)
        gated = self.fc_full(stacked)
        # BatchNorm1d needs a 2-D view; restore the (N, M, 2F) shape after.
        gated = self.bn1(gated.view(-1, 2 * self.atom_fea_len))
        gated = gated.view(n_atoms, n_nbrs, 2 * self.atom_fea_len)
        gate, core = gated.chunk(2, dim=2)
        # Sigmoid gate weights how much each neighbor message contributes.
        message = torch.sum(self.sigmoid(gate) * self.softplus1(core), dim=1)
        return self.softplus2(atom_in_fea + self.bn2(message))
|
|
228
|
+
|
|
229
|
+
class CrystalGraphConvNet(nn.Module):
    """Crystal Graph Convolutional Network (CGCNN).

    Embeds raw atom features, applies ``n_conv`` ConvLayer passes,
    mean-pools atoms to one vector per crystal, then maps through hidden
    layers to the output head: a 2-way log-softmax when ``classification``
    is True, else a single regression unit.
    """
    def __init__(self, orig_atom_fea_len, nbr_fea_len,
                 atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1,
                 classification=False):
        super(CrystalGraphConvNet, self).__init__()
        self.classification = classification
        # Linear embedding of raw atom features into the working width.
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
        self.convs = nn.ModuleList([ConvLayer(atom_fea_len=atom_fea_len,
                                              nbr_fea_len=nbr_fea_len)
                                    for _ in range(n_conv)])
        self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()
        # Last hidden crystal feature, stored by forward() for inspection.
        self.final_fea = 0
        # Extra hidden layers are created only for n_h > 1; forward()
        # probes for them with hasattr.
        if n_h > 1:
            self.fcs = nn.ModuleList([nn.Linear(h_fea_len, h_fea_len)
                                      for _ in range(n_h-1)])
            self.softpluses = nn.ModuleList([nn.Softplus()
                                             for _ in range(n_h-1)])
        if self.classification:
            self.fc_out = nn.Linear(h_fea_len, 2)
        else:
            self.fc_out = nn.Linear(h_fea_len, 1)
        if self.classification:
            self.logsoftmax = nn.LogSoftmax(dim=1)
            self.dropout = nn.Dropout()
    def forward(self, atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx):
        """Return per-crystal predictions for a batched graph.

        ``crystal_atom_idx`` lists, for each crystal, the indices of its
        atoms inside the batched atom dimension (as built by
        ``collate_pool``).
        """
        atom_fea = self.embedding(atom_fea)
        for conv_func in self.convs:
            atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
        crys_fea = self.pooling(atom_fea, crystal_atom_idx)
        crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
        crys_fea = self.conv_to_fc_softplus(crys_fea)
        if self.classification:
            crys_fea = self.dropout(crys_fea)
        if hasattr(self, 'fcs') and hasattr(self, 'softpluses'):
            for fc, softplus in zip(self.fcs, self.softpluses):
                crys_fea = softplus(fc(crys_fea))

        # Kept on the instance so callers can read the pooled features.
        self.final_fea = crys_fea

        out = self.fc_out(crys_fea)
        if self.classification:
            out = self.logsoftmax(out)
        return out
    def pooling(self, atom_fea, crystal_atom_idx):
        """Mean-pool atom features to one feature vector per crystal."""
        # Every batched atom must belong to exactly one crystal.
        assert sum([len(idx_map) for idx_map in crystal_atom_idx]) ==\
            atom_fea.data.shape[0]
        summed_fea = [torch.mean(atom_fea[idx_map], dim=0, keepdim=True)
                      for idx_map in crystal_atom_idx]
        return torch.cat(summed_fea, dim=0)
|
|
279
|
+
|
|
280
|
+
class Normalizer(object):
    """Standardize tensors with the mean/std of a reference sample."""

    def __init__(self, tensor):
        # Statistics are taken over all elements of the reference tensor.
        self.mean = torch.mean(tensor)
        self.std = torch.std(tensor)

    def norm(self, tensor):
        """Return ``tensor`` shifted/scaled to zero mean, unit std."""
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor):
        """Invert :meth:`norm`."""
        return normed_tensor * self.std + self.mean

    def state_dict(self):
        """Return the statistics needed to reconstruct this normalizer."""
        return {'mean': self.mean, 'std': self.std}

    def load_state_dict(self, state_dict):
        """Restore statistics produced by :meth:`state_dict`."""
        self.mean = state_dict['mean']
        self.std = state_dict['std']
|
|
294
|
+
|
|
295
|
+
def predict(root_cif,
            atom_init_file=os.path.join(package_directory, "atom_init.json"),
            model="core"):
    """Score one structure with the 100-bag CGCNN ensemble.

    Parameters
    ----------
    root_cif : str
        Path to the ``.cif`` structure to classify.
    atom_init_file : str
        Path to the JSON table of per-element feature vectors.
    model : str
        Ensemble to use: "core", "qsp" or "h".

    Returns
    -------
    list
        ``[cif_id, all_scores, mean_score]`` where ``all_scores`` holds
        the positive-class probability from each available bag.

    Raises
    ------
    ValueError
        If ``model`` is not a supported ensemble name.
    """
    use_cuda = torch.cuda.is_available()
    ensembles = {"core": "models", "qsp": "models_qsp", "h": "models_h"}
    if model not in ensembles:
        # The original only printed here and then crashed with a NameError
        # on model_dir; fail fast with a clear error instead.
        raise ValueError("Currently only core, qsp or h are supported.")
    model_dir = os.path.join(package_directory, ensembles[model])

    # Featurize the structure ONCE and reuse the loader for every bag
    # (the original re-ran preprocessing inside the 100-bag loop).
    dataset_test = [preprocess(root_cif=root_cif, atom_init_file=atom_init_file)]
    test_loader = DataLoader(dataset_test, batch_size=1, shuffle=False,
                             num_workers=0, collate_fn=collate_pool,
                             pin_memory=use_cuda)
    structures, _ = dataset_test[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]

    models_100 = []
    test_cif_ids = []
    for i in tqdm(range(1, 101)):
        modelpath = os.path.join(model_dir, 'checkpoint_bag_' + str(i) + '.pth.tar')
        if not os.path.isfile(modelpath):
            print("=> no model found at '{}'".format(modelpath))
            continue
        # Load each checkpoint once (the original loaded it twice).
        checkpoint = torch.load(modelpath, weights_only=False,
                                map_location=lambda storage, loc: storage)
        model_args = argparse.Namespace(**checkpoint['args'])
        # "net" avoids shadowing the `model` name argument.
        net = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                  atom_fea_len=model_args.atom_fea_len,
                                  n_conv=model_args.n_conv,
                                  h_fea_len=model_args.h_fea_len,
                                  n_h=model_args.n_h,
                                  classification=True)
        if use_cuda:
            net.cuda()
        normalizer = Normalizer(torch.zeros(3))
        net.load_state_dict(checkpoint['state_dict'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        net.eval()
        test_preds = []
        bag_cif_ids = []
        for input, batch_cif_ids in test_loader:
            with torch.no_grad():
                if use_cuda:
                    input_var = (Variable(input[0].cuda(non_blocking=True)),
                                 Variable(input[1].cuda(non_blocking=True)),
                                 input[2].cuda(non_blocking=True),
                                 [crys_idx.cuda(non_blocking=True)
                                  for crys_idx in input[3]])
                else:
                    input_var = (Variable(input[0]),
                                 Variable(input[1]),
                                 input[2],
                                 input[3])
                output = net(*input_var)
                # The head emits log-probabilities; exponentiate and keep
                # the positive ("is a valid MOF") class.
                test_pred = torch.exp(output.data.cpu())
                assert test_pred.shape[1] == 2
                test_preds += test_pred[:, 1].tolist()
                bag_cif_ids += batch_cif_ids
        models_100.extend(test_preds)
        if not test_cif_ids:
            test_cif_ids = bag_cif_ids
    CLscore = np.mean(models_100)
    return [test_cif_ids[0], models_100, CLscore]
|
|
364
|
+
|
|
365
|
+
|
|
366
|
+
def predict_batch(
    root_cifs,
    atom_init_file=os.path.join(package_directory, "atom_init.json"),
    model="core",
    batch_size=512,
):
    """Score many structures with the 100-bag CGCNN ensemble.

    Parameters
    ----------
    root_cifs : list[str]
        Paths of the ``.cif`` structures to classify.
    atom_init_file : str
        Path to the JSON table of per-element feature vectors.
    model : str
        Ensemble to use: "core", "qsp" or "h".
    batch_size : int
        Number of structures per forward pass.

    Returns
    -------
    list
        One ``(cif_id, all_scores, mean_score)`` tuple per input structure.

    Raises
    ------
    ValueError
        If ``model`` is not a supported ensemble name.
    """
    use_cuda = torch.cuda.is_available()
    ensembles = {"core": "models", "qsp": "models_qsp", "h": "models_h"}
    if model not in ensembles:
        # Original message wrongly omitted "h" and execution then crashed
        # with a NameError; raise a clear error instead.
        raise ValueError("Currently only core, qsp or h are supported.")
    model_dir = os.path.join(package_directory, ensembles[model])

    # Featurize every structure once; all 100 bags share the loader.
    dataset_test = [
        preprocess(root_cif=root_cif, atom_init_file=atom_init_file)
        for root_cif in root_cifs
    ]
    test_loader = DataLoader(
        dataset_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_pool,
        pin_memory=use_cuda,
    )
    structures, _ = dataset_test[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]

    models_100 = []
    test_cif_ids = []
    for i in tqdm(range(1, 101)):
        modelpath = os.path.join(model_dir, "checkpoint_bag_" + str(i) + ".pth.tar")
        if not os.path.isfile(modelpath):
            print("=> no model found at '{}'".format(modelpath))
            continue
        # Load each checkpoint once (the original loaded it twice).
        checkpoint = torch.load(modelpath, weights_only=False,
                                map_location=lambda storage, loc: storage)
        model_args = argparse.Namespace(**checkpoint['args'])
        # "net" avoids shadowing the `model` name argument.
        net = CrystalGraphConvNet(orig_atom_fea_len, nbr_fea_len,
                                  atom_fea_len=model_args.atom_fea_len,
                                  n_conv=model_args.n_conv,
                                  h_fea_len=model_args.h_fea_len,
                                  n_h=model_args.n_h,
                                  classification=True)
        if use_cuda:
            net.cuda()
        normalizer = Normalizer(torch.zeros(3))
        net.load_state_dict(checkpoint['state_dict'])
        normalizer.load_state_dict(checkpoint['normalizer'])
        net.eval()
        record_ids = not test_cif_ids  # ids identical for every bag
        test_preds = []
        for input, batch_cif_ids in test_loader:
            with torch.no_grad():
                if use_cuda:
                    input_var = (
                        Variable(input[0].cuda()),
                        Variable(input[1].cuda()),
                        input[2].cuda(),
                        [crys_idx.cuda() for crys_idx in input[3]],
                    )
                else:
                    input_var = (
                        Variable(input[0]),
                        Variable(input[1]),
                        input[2],
                        input[3],
                    )
                output = net(*input_var)
                # The head emits log-probabilities; exponentiate and keep
                # the positive class as plain floats (the original stored
                # 0-dim tensors here).
                test_pred = torch.exp(output.data.cpu())
                assert test_pred.shape[1] == 2
                test_preds += test_pred[:, 1].tolist()
                if record_ids:
                    test_cif_ids += batch_cif_ids
        models_100.append(test_preds)
    # Transpose bag-major scores to structure-major: row idx -> structure.
    models_100 = np.asarray(models_100).T
    CLscore = np.mean(models_100, axis=1).tolist()
    return [
        (test_cif_ids[idx], models_100[idx].tolist(), CLscore[idx])
        for idx in range(len(root_cifs))
    ]
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: mofclassifier
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: A Machine Learning Approach for Validating Computation-Ready Metal-Organic Frameworks
|
|
5
|
+
Home-page: https://github.com/Chung-Research-Group/MOFClassifier
|
|
6
|
+
Author: Guobin Zhao
|
|
7
|
+
Author-email: sxmzhaogb@gmai.com
|
|
8
|
+
License: CC-BY-4.0
|
|
9
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering :: Chemistry
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
13
|
+
Requires-Python: >=3.9, <4
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
Requires-Dist: ase
|
|
16
|
+
Requires-Dist: numpy==1.26.4
|
|
17
|
+
Requires-Dist: torch==2.7.0
|
|
18
|
+
Requires-Dist: Pymatgen==2024.8.9
|
|
19
|
+
Requires-Dist: scikit-learn==1.3.2
|
|
20
|
+
Requires-Dist: tqdm==4.67.1
|
|
21
|
+
Requires-Dist: pandas==2.2.3
|
|
22
|
+
|
|
23
|
+
## MOFClassifier: A Machine Learning Approach for Validating Computation-Ready Metal-Organic Frameworks
|
|
24
|
+
|
|
25
|
+
<div align="center">
|
|
26
|
+
<img src="https://raw.githubusercontent.com/sxm13/pypi-dev/main/logos/mofclassifier.png" alt="mofclassifier logo" width="500"/>
|
|
27
|
+
</div>
|
|
28
|
+
|
|
29
|
+
[](https://arxiv.org/abs/2506.14845)
|
|
30
|
+
[](https://pubs.acs.org/doi/10.1021/jacs.5c10126)
|
|
31
|
+

|
|
32
|
+
[](https://pypi.org/project/MOFClassifier?logo=pypi&logoColor=white)
|
|
33
|
+
[](https://python.org/downloads)
|
|
34
|
+
[](https://github.com/sxm13/pypi-dev/blob/main/LICENSE)
|
|
35
|
+
[](https://GitHub.com/Chung-Research-Group/MOFClassifier/issues/)
|
|
36
|
+
[](https://doi.org/10.5281/zenodo.15654431)
|
|
37
|
+
|
|
38
|
+
### Installation
|
|
39
|
+
|
|
40
|
+
```sh
|
|
41
|
+
pip install MOFClassifier
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
### Examples
|
|
45
|
+
```python
|
|
46
|
+
from MOFClassifier import CLscore
|
|
47
|
+
result = CLscore.predict(root_cif="./example.cif", model="core")
|
|
48
|
+
```
|
|
49
|
+
- **root_cif**: the path of your structure
|
|
50
|
+
- **model**: the model name: a. "core": training with CoRE MOF DB; b. "qsp": training with CoRE MOF DB and QMOF DB; c. "h": training with ToBaCCo (Hypothetical MOFs)
|
|
51
|
+
- **result**: a. cifid: the name of structure; b. all_score: the CLscore predicted by 100 models (bags); c. mean_score: the mean CLscore of CLscores
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
from MOFClassifier import CLscore
|
|
55
|
+
results = CLscore.predict_batch(root_cifs=["./example1.cif""./example2.cif","./example3.cif"], model="core", batch_size=512)
|
|
56
|
+
```
|
|
57
|
+
- **root_cifs**: the path of your structures
|
|
58
|
+
- **model**: the model name: a. "core": training with CoRE MOF DB; b. "qsp": training with CoRE MOF DB and QMOF DB; c. "h": training with ToBaCCo (Hypothetical MOFs)
|
|
59
|
+
- **batch_size**: the number of samples
|
|
60
|
+
- **results**: a. cifid: the name of structure; b. all_score: the CLscore predicted by 100 models (bags); c. mean_score: the mean CLscore of CLscores
|
|
61
|
+
|
|
62
|
+
### Citation
|
|
63
|
+
[Guobin Zhao, Pengyu Zhao and Yongchul G. Chung. ***Journal of the American Chemical Society***, 2025, 147, 37, 33343–33349. DOI: 10.1021/jacs.5c10126](https://pubs.acs.org/doi/10.1021/jacs.5c10126)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
### Acknowledgments
|
|
67
|
+
We thank [henk789](https://github.com/henk789) for contribution to batch prediction.
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
## MOFClassifier: A Machine Learning Approach for Validating Computation-Ready Metal-Organic Frameworks
|
|
2
|
+
|
|
3
|
+
<div align="center">
|
|
4
|
+
<img src="https://raw.githubusercontent.com/sxm13/pypi-dev/main/logos/mofclassifier.png" alt="mofclassifier logo" width="500"/>
|
|
5
|
+
</div>
|
|
6
|
+
|
|
7
|
+
[](https://arxiv.org/abs/2506.14845)
|
|
8
|
+
[](https://pubs.acs.org/doi/10.1021/jacs.5c10126)
|
|
9
|
+

|
|
10
|
+
[](https://pypi.org/project/MOFClassifier?logo=pypi&logoColor=white)
|
|
11
|
+
[](https://python.org/downloads)
|
|
12
|
+
[](https://github.com/sxm13/pypi-dev/blob/main/LICENSE)
|
|
13
|
+
[](https://GitHub.com/Chung-Research-Group/MOFClassifier/issues/)
|
|
14
|
+
[](https://doi.org/10.5281/zenodo.15654431)
|
|
15
|
+
|
|
16
|
+
### Installation
|
|
17
|
+
|
|
18
|
+
```sh
|
|
19
|
+
pip install MOFClassifier
|
|
20
|
+
```
|
|
21
|
+
|
|
22
|
+
### Examples
|
|
23
|
+
```python
|
|
24
|
+
from MOFClassifier import CLscore
|
|
25
|
+
result = CLscore.predict(root_cif="./example.cif", model="core")
|
|
26
|
+
```
|
|
27
|
+
- **root_cif**: the path of your structure
|
|
28
|
+
- **model**: the model name: a. "core": training with CoRE MOF DB; b. "qsp": training with CoRE MOF DB and QMOF DB; c. "h": training with ToBaCCo (Hypothetical MOFs)
|
|
29
|
+
- **result**: a. cifid: the name of structure; b. all_score: the CLscore predicted by 100 models (bags); c. mean_score: the mean CLscore of CLscores
|
|
30
|
+
|
|
31
|
+
```python
|
|
32
|
+
from MOFClassifier import CLscore
|
|
33
|
+
results = CLscore.predict_batch(root_cifs=["./example1.cif""./example2.cif","./example3.cif"], model="core", batch_size=512)
|
|
34
|
+
```
|
|
35
|
+
- **root_cifs**: the path of your structures
|
|
36
|
+
- **model**: the model name: a. "core": training with CoRE MOF DB; b. "qsp": training with CoRE MOF DB and QMOF DB; c. "h": training with ToBaCCo (Hypothetical MOFs)
|
|
37
|
+
- **batch_size**: the number of samples
|
|
38
|
+
- **results**: a. cifid: the name of structure; b. all_score: the CLscore predicted by 100 models (bags); c. mean_score: the mean CLscore of CLscores
|
|
39
|
+
|
|
40
|
+
### Citation
|
|
41
|
+
[Guobin Zhao, Pengyu Zhao and Yongchul G. Chung. ***Journal of the American Chemical Society***, 2025, 147, 37, 33343–33349. DOI: 10.1021/jacs.5c10126](https://pubs.acs.org/doi/10.1021/jacs.5c10126)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
### Acknowledgments
|
|
45
|
+
We thank [henk789](https://github.com/henk789) for contribution to batch prediction.
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: mofclassifier
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: A Machine Learning Approach for Validating Computation-Ready Metal-Organic Frameworks
|
|
5
|
+
Home-page: https://github.com/Chung-Research-Group/MOFClassifier
|
|
6
|
+
Author: Guobin Zhao
|
|
7
|
+
Author-email: sxmzhaogb@gmai.com
|
|
8
|
+
License: CC-BY-4.0
|
|
9
|
+
Classifier: Development Status :: 5 - Production/Stable
|
|
10
|
+
Classifier: Intended Audience :: Developers
|
|
11
|
+
Classifier: Topic :: Scientific/Engineering :: Chemistry
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
13
|
+
Requires-Python: >=3.9, <4
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
Requires-Dist: ase
|
|
16
|
+
Requires-Dist: numpy==1.26.4
|
|
17
|
+
Requires-Dist: torch==2.7.0
|
|
18
|
+
Requires-Dist: Pymatgen==2024.8.9
|
|
19
|
+
Requires-Dist: scikit-learn==1.3.2
|
|
20
|
+
Requires-Dist: tqdm==4.67.1
|
|
21
|
+
Requires-Dist: pandas==2.2.3
|
|
22
|
+
|
|
23
|
+
## MOFClassifier: A Machine Learning Approach for Validating Computation-Ready Metal-Organic Frameworks
|
|
24
|
+
|
|
25
|
+
<div align="center">
|
|
26
|
+
<img src="https://raw.githubusercontent.com/sxm13/pypi-dev/main/logos/mofclassifier.png" alt="mofclassifier logo" width="500"/>
|
|
27
|
+
</div>
|
|
28
|
+
|
|
29
|
+
[](https://arxiv.org/abs/2506.14845)
|
|
30
|
+
[](https://pubs.acs.org/doi/10.1021/jacs.5c10126)
|
|
31
|
+

|
|
32
|
+
[](https://pypi.org/project/MOFClassifier?logo=pypi&logoColor=white)
|
|
33
|
+
[](https://python.org/downloads)
|
|
34
|
+
[](https://github.com/sxm13/pypi-dev/blob/main/LICENSE)
|
|
35
|
+
[](https://GitHub.com/Chung-Research-Group/MOFClassifier/issues/)
|
|
36
|
+
[](https://doi.org/10.5281/zenodo.15654431)
|
|
37
|
+
|
|
38
|
+
### Installation
|
|
39
|
+
|
|
40
|
+
```sh
|
|
41
|
+
pip install MOFClassifier
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
### Examples
|
|
45
|
+
```python
|
|
46
|
+
from MOFClassifier import CLscore
|
|
47
|
+
result = CLscore.predict(root_cif="./example.cif", model="core")
|
|
48
|
+
```
|
|
49
|
+
- **root_cif**: the path of your structure
|
|
50
|
+
- **model**: the model name: a. "core": training with CoRE MOF DB; b. "qsp": training with CoRE MOF DB and QMOF DB; c. "h": training with ToBaCCo (Hypothetical MOFs)
|
|
51
|
+
- **result**: a. cifid: the name of structure; b. all_score: the CLscore predicted by 100 models (bags); c. mean_score: the mean CLscore of CLscores
|
|
52
|
+
|
|
53
|
+
```python
|
|
54
|
+
from MOFClassifier import CLscore
|
|
55
|
+
results = CLscore.predict_batch(root_cifs=["./example1.cif""./example2.cif","./example3.cif"], model="core", batch_size=512)
|
|
56
|
+
```
|
|
57
|
+
- **root_cifs**: the path of your structures
|
|
58
|
+
- **model**: the model name: a. "core": training with CoRE MOF DB; b. "qsp": training with CoRE MOF DB and QMOF DB; c. "h": training with ToBaCCo (Hypothetical MOFs)
|
|
59
|
+
- **batch_size**: the number of samples
|
|
60
|
+
- **results**: a. cifid: the name of structure; b. all_score: the CLscore predicted by 100 models (bags); c. mean_score: the mean CLscore of CLscores
|
|
61
|
+
|
|
62
|
+
### Citation
|
|
63
|
+
[Guobin Zhao, Pengyu Zhao and Yongchul G. Chung. ***Journal of the American Chemical Society***, 2025, 147, 37, 33343–33349. DOI: 10.1021/jacs.5c10126](https://pubs.acs.org/doi/10.1021/jacs.5c10126)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
### Acknowledgments
|
|
67
|
+
We thank [henk789](https://github.com/henk789) for contribution to batch prediction.
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
setup.py
|
|
3
|
+
MOFClassifier/CLscore.py
|
|
4
|
+
MOFClassifier/__init__.py
|
|
5
|
+
mofclassifier.egg-info/PKG-INFO
|
|
6
|
+
mofclassifier.egg-info/SOURCES.txt
|
|
7
|
+
mofclassifier.egg-info/dependency_links.txt
|
|
8
|
+
mofclassifier.egg-info/requires.txt
|
|
9
|
+
mofclassifier.egg-info/top_level.txt
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
MOFClassifier
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from setuptools import setup, find_packages


def _long_description():
    """Return the README contents for PyPI's project page."""
    # Explicit encoding and context manager; the original left the file
    # handle open and depended on the platform default encoding.
    with open('README.md', encoding='utf-8') as f:
        return f.read()


# NOTE: the original `import requests` was unused and would break builds
# in environments without requests installed; it has been removed.
setup(
    name="mofclassifier",
    version="0.1.2",
    packages=find_packages(),
    description="A Machine Learning Approach for Validating Computation-Ready Metal-Organic Frameworks",
    author="Guobin Zhao",
    author_email="sxmzhaogb@gmai.com",
    url="https://github.com/Chung-Research-Group/MOFClassifier",
    long_description=_long_description(),
    long_description_content_type='text/markdown',
    license="CC-BY-4.0",
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Chemistry',
        'Programming Language :: Python :: 3.9',
    ],
    install_requires=[
        "ase",
        "numpy==1.26.4",
        "torch==2.7.0",
        "Pymatgen==2024.8.9",
        "scikit-learn==1.3.2",
        "tqdm==4.67.1",
        "pandas==2.2.3"
    ],
    license_files=("LICENSE",),
    python_requires='>=3.9, <4',
)
|