molcraft 0.1.0a6__tar.gz → 0.1.0a8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of molcraft might be problematic. Click here for more details.
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/PKG-INFO +2 -2
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/README.md +1 -1
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/__init__.py +1 -1
- molcraft-0.1.0a8/molcraft/callbacks.py +100 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/chem.py +45 -30
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/conformers.py +0 -4
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/features.py +3 -9
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/featurizers.py +18 -26
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/layers.py +466 -801
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/models.py +16 -1
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/ops.py +14 -3
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft.egg-info/PKG-INFO +2 -2
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_layers.py +9 -3
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_models.py +60 -3
- molcraft-0.1.0a6/molcraft/callbacks.py +0 -33
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/LICENSE +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/datasets.py +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/descriptors.py +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/losses.py +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/records.py +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft/tensors.py +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft.egg-info/SOURCES.txt +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft.egg-info/dependency_links.txt +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft.egg-info/requires.txt +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/molcraft.egg-info/top_level.txt +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/pyproject.toml +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/setup.cfg +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_chem.py +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_featurizers.py +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_losses.py +0 -0
- {molcraft-0.1.0a6 → molcraft-0.1.0a8}/tests/test_tensors.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: molcraft
|
|
3
|
-
Version: 0.1.
|
|
3
|
+
Version: 0.1.0a8
|
|
4
4
|
Summary: Graph Neural Networks for Molecular Machine Learning
|
|
5
5
|
Author-email: Alexander Kensert <alexander.kensert@gmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -78,7 +78,7 @@ featurizer = featurizers.MolGraphFeaturizer(
|
|
|
78
78
|
features.IsRotatable(),
|
|
79
79
|
],
|
|
80
80
|
super_atom=True,
|
|
81
|
-
self_loops=
|
|
81
|
+
self_loops=True,
|
|
82
82
|
)
|
|
83
83
|
|
|
84
84
|
graph = featurizer([('N[C@@H](C)C(=O)O', 2.0), ('N[C@@H](CS)C(=O)O', 1.0)])
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
import keras
|
|
2
|
+
import warnings
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TensorBoard(keras.callbacks.TensorBoard):
    """TensorBoard callback that tags weight summaries by variable path.

    The weight-logging hook is overridden so that every summary tag is built
    from `weight.path` instead of `weight.name`, which distinguishes weights
    of different layers.
    """

    def _log_weights(self, epoch):
        """Write a histogram (and optionally an image) summary per weight.

        Args:
            epoch: Step value attached to each written summary.
        """
        writer = self._train_writer
        with writer.as_default():
            # Walk every weight of every layer, in layer order.
            all_weights = (
                weight
                for layer in self.model.layers
                for weight in layer.weights
            )
            for weight in all_weights:
                tag_prefix = weight.path
                self.summary.histogram(
                    tag_prefix + "/histogram", weight, step=epoch
                )
                if not self.write_images:
                    continue
                self._log_weight_as_image(
                    weight, tag_prefix + "/image", epoch
                )
            writer.flush()
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class LearningRateDecay(keras.callbacks.LearningRateScheduler):
    """Learning rate scheduler applying an exponential decay factor.

    Args:
        rate: Decay constant. Once the delay has passed, every call of the
            schedule multiplies the incoming learning rate by `exp(-rate)`.
        delay: Epoch index below which the learning rate is returned
            unchanged.
        **kwargs: Forwarded to `keras.callbacks.LearningRateScheduler`.
    """

    def __init__(self, rate: float, delay: int = 0, **kwargs):

        def decayed(epoch: int, lr: float):
            # Decay only once the delay period is over.
            if epoch >= delay:
                return float(lr * keras.ops.exp(-rate))
            return float(lr)

        super().__init__(schedule=decayed, **kwargs)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class Rollback(keras.callbacks.Callback):
    """Rollback callback.

    Restores the model weights and (optionally) the optimizer variables to
    the best state observed so far whenever the monitored loss becomes
    NaN/inf or deviates too much from the best observed loss.

    This callback might be useful in situations where the loss tends to spike
    and put the model in an undesired/problematic high-loss parameter space.

    The monitored loss is `val_loss` when present in the epoch logs,
    otherwise `loss`.

    Args:
        tolerance (float):
            The threshold for when the restoration is triggered. The deviation
            is calculated as follows:
            (current_loss - best_loss) / abs(best_loss).
        rollback_optimizer (bool):
            Whether the optimizer variables are snapshotted and restored
            together with the model weights.
    """

    # Lower bound on the denominator of the relative deviation, so a best
    # loss of exactly zero does not cause a division by zero.
    _min_scale = 1e-7

    def __init__(
        self,
        tolerance: float = 0.5,
        rollback_optimizer: bool = True,
    ):
        super().__init__()
        self.tolerance = tolerance
        self.rollback_optimizer = rollback_optimizer

    def on_train_begin(self, logs=None):
        """Snapshot the initial state and reset the best observed loss."""
        self._rollback_weights = self._get_model_vars()
        if self.rollback_optimizer:
            self._rollback_optimizer_vars = self._get_optimizer_vars()
        self._rollback_loss = float('inf')

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """Roll back, or record a new best state, based on this epoch's loss.

        Args:
            epoch: Index of the epoch that just ended (unused).
            logs: Epoch metrics; `val_loss` is preferred over `loss`.
        """
        logs = logs or {}
        current_loss = logs.get('val_loss', logs.get('loss'))
        if current_loss is None:
            # Nothing to monitor this epoch; leave the saved state untouched.
            return

        if not np.isfinite(current_loss):
            # Rolling back model because of nan or inf loss.
            self._rollback()
            return

        if self._deviation(current_loss) > self.tolerance:
            # Rolling back model because of large loss deviation.
            self._rollback()
            return

        if current_loss < self._rollback_loss:
            self._save_state(current_loss)

    def _deviation(self, current_loss: float) -> float:
        """Return the deviation of `current_loss` relative to the best loss."""
        best_loss = self._rollback_loss
        if not np.isfinite(best_loss):
            # No best loss recorded yet: never exceeds any finite tolerance.
            return float('-inf')
        # Dividing by the magnitude keeps the sign meaningful for negative
        # losses (a positive deviation always means "worse than best").
        scale = max(abs(best_loss), self._min_scale)
        return (current_loss - best_loss) / scale

    def _save_state(self, current_loss: float) -> None:
        """Record `current_loss` as best and snapshot the current state."""
        self._rollback_loss = current_loss
        self._rollback_weights = self._get_model_vars()
        if self.rollback_optimizer:
            self._rollback_optimizer_vars = self._get_optimizer_vars()

    def _rollback(self) -> None:
        """Restore the most recently saved model (and optimizer) state."""
        self.model.set_weights(self._rollback_weights)
        if self.rollback_optimizer:
            self.model.optimizer.set_weights(self._rollback_optimizer_vars)

    def _get_optimizer_vars(self):
        """Return copies of the optimizer variables as NumPy arrays."""
        return [v.numpy() for v in self.model.optimizer.variables]

    def _get_model_vars(self):
        """Return the current model weights."""
        return self.model.get_weights()
|
|
@@ -102,18 +102,20 @@ class Mol(Chem.Mol):
|
|
|
102
102
|
|
|
103
103
|
def get_conformer(self, index: int = 0) -> 'Conformer':
|
|
104
104
|
if self.num_conformers == 0:
|
|
105
|
-
warn(
|
|
105
|
+
warnings.warn(
|
|
106
106
|
'Molecule has no conformer. To embed conformer(s), invoke the `embed` method, '
|
|
107
|
-
'and optionally followed by `minimize()` to perform force field minimization.'
|
|
107
|
+
'and optionally followed by `minimize()` to perform force field minimization.',
|
|
108
|
+
stacklevel=2
|
|
108
109
|
)
|
|
109
110
|
return None
|
|
110
111
|
return Conformer.cast(self.GetConformer(index))
|
|
111
112
|
|
|
112
113
|
def get_conformers(self) -> list['Conformer']:
|
|
113
114
|
if self.num_conformers == 0:
|
|
114
|
-
warn(
|
|
115
|
+
warnings.warn(
|
|
115
116
|
'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
|
|
116
|
-
'and optionally followed by `minimize()` to perform force field minimization.'
|
|
117
|
+
'and optionally followed by `minimize()` to perform force field minimization.',
|
|
118
|
+
stacklevel=2
|
|
117
119
|
)
|
|
118
120
|
return []
|
|
119
121
|
return [Conformer.cast(x) for x in self.GetConformers()]
|
|
@@ -400,7 +402,6 @@ def embed_conformers(
|
|
|
400
402
|
mol: Mol,
|
|
401
403
|
num_conformers: int,
|
|
402
404
|
method: str = 'ETKDGv3',
|
|
403
|
-
force: bool = True,
|
|
404
405
|
**kwargs
|
|
405
406
|
) -> None:
|
|
406
407
|
available_embedding_methods = {
|
|
@@ -411,27 +412,40 @@ def embed_conformers(
|
|
|
411
412
|
'srETKDGv3': rdDistGeom.srETKDGv3(),
|
|
412
413
|
'KDG': rdDistGeom.KDG()
|
|
413
414
|
}
|
|
414
|
-
default_embedding_method = 'ETKDGv3'
|
|
415
415
|
mol = Mol(mol)
|
|
416
|
-
|
|
417
|
-
if
|
|
418
|
-
|
|
419
|
-
f
|
|
420
|
-
|
|
416
|
+
embedding_method = available_embedding_methods.get(method)
|
|
417
|
+
if embedding_method is None:
|
|
418
|
+
raise ValueError(
|
|
419
|
+
f'Could not find `method` {method!r}. Specify either of: '
|
|
420
|
+
'`ETDG`, `ETKDG`, `ETKDGv2`, `ETKDGv3`, `srETKDGv3` or `KDG`.'
|
|
421
421
|
)
|
|
422
|
-
|
|
422
|
+
|
|
423
423
|
for key, value in kwargs.items():
|
|
424
|
-
setattr(
|
|
424
|
+
setattr(embedding_method, key, value)
|
|
425
425
|
|
|
426
|
-
success = rdDistGeom.EmbedMultipleConfs(
|
|
426
|
+
success = rdDistGeom.EmbedMultipleConfs(
|
|
427
|
+
mol, numConfs=num_conformers, params=embedding_method
|
|
428
|
+
)
|
|
427
429
|
if not len(success):
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
430
|
+
warnings.warn(
|
|
431
|
+
f'Could not embed conformer(s) for {mol.canonical_smiles!r} using the '
|
|
432
|
+
'speified method. Giving it another try with more permissive methods.',
|
|
433
|
+
stacklevel=2
|
|
434
|
+
)
|
|
435
|
+
max_attempts = (20 * mol.num_atoms) # increasing it from 10xN to 20xN
|
|
436
|
+
for fallback_method in [method, 'ETDG', 'KDG']:
|
|
437
|
+
fallback_embedding_method = available_embedding_methods[fallback_method]
|
|
438
|
+
fallback_embedding_method.useRandomCoords = True
|
|
439
|
+
fallback_embedding_method.maxAttempts = max_attempts
|
|
440
|
+
success = rdDistGeom.EmbedMultipleConfs(
|
|
441
|
+
mol, numConfs=num_conformers, params=fallback_embedding_method
|
|
442
|
+
)
|
|
443
|
+
if len(success):
|
|
444
|
+
break
|
|
431
445
|
else:
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
446
|
+
raise RuntimeError(
|
|
447
|
+
f'Could not embed conformer(s) for {mol.canonical_smiles!r}. '
|
|
448
|
+
)
|
|
435
449
|
return mol
|
|
436
450
|
|
|
437
451
|
def optimize_conformers(
|
|
@@ -445,6 +459,11 @@ def optimize_conformers(
|
|
|
445
459
|
available_force_field_methods = [
|
|
446
460
|
'MMFF', 'MMFF94', 'MMFF94s', 'UFF'
|
|
447
461
|
]
|
|
462
|
+
if method not in available_force_field_methods:
|
|
463
|
+
raise ValueError(
|
|
464
|
+
f'Could not find `method` {method!r}. Specify either of: '
|
|
465
|
+
'`UFF`, `MMFF`, `MMFF94` or `MMFF94s`.'
|
|
466
|
+
)
|
|
448
467
|
mol = Mol(mol)
|
|
449
468
|
try:
|
|
450
469
|
if method.startswith('MMFF'):
|
|
@@ -467,9 +486,10 @@ def optimize_conformers(
|
|
|
467
486
|
ignore_interfragment_interactions=ignore_interfragment_interactions,
|
|
468
487
|
)
|
|
469
488
|
except RuntimeError as e:
|
|
470
|
-
warn(
|
|
489
|
+
warnings.warn(
|
|
471
490
|
f'{method} force field minimization raised {e}. '
|
|
472
|
-
'\nProceeding without force field minimization
|
|
491
|
+
'\nProceeding without force field minimization.',
|
|
492
|
+
stacklevel=2
|
|
473
493
|
)
|
|
474
494
|
return mol
|
|
475
495
|
|
|
@@ -480,9 +500,10 @@ def prune_conformers(
|
|
|
480
500
|
energy_force_field: str = 'UFF',
|
|
481
501
|
):
|
|
482
502
|
if mol.num_conformers == 0:
|
|
483
|
-
warn(
|
|
503
|
+
warnings.warn(
|
|
484
504
|
'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
|
|
485
|
-
'and optionally followed by `minimize()` to perform force field minimization.'
|
|
505
|
+
'and optionally followed by `minimize()` to perform force field minimization.',
|
|
506
|
+
stacklevel=2
|
|
486
507
|
)
|
|
487
508
|
return mol
|
|
488
509
|
|
|
@@ -658,9 +679,3 @@ def _atom_pair_fingerprint(
|
|
|
658
679
|
fp_param = {'fpSize': size}
|
|
659
680
|
return _get_fingerprint(mol, 'atom_pair', binary, dtype, **fp_param)
|
|
660
681
|
|
|
661
|
-
def warn(message: str) -> None:
|
|
662
|
-
warnings.warn(
|
|
663
|
-
message=message,
|
|
664
|
-
category=UserWarning,
|
|
665
|
-
stacklevel=1,
|
|
666
|
-
)
|
|
@@ -24,19 +24,16 @@ class ConformerEmbedder(ConformerProcessor):
|
|
|
24
24
|
self,
|
|
25
25
|
method: str = 'ETKDGv3',
|
|
26
26
|
num_conformers: int = 5,
|
|
27
|
-
force: bool = True,
|
|
28
27
|
**kwargs,
|
|
29
28
|
) -> None:
|
|
30
29
|
self.method = method
|
|
31
30
|
self.num_conformers = num_conformers
|
|
32
|
-
self.force = force
|
|
33
31
|
self.kwargs = kwargs
|
|
34
32
|
|
|
35
33
|
def get_config(self) -> dict:
|
|
36
34
|
config = {
|
|
37
35
|
'method': self.method,
|
|
38
36
|
'num_conformers': self.num_conformers,
|
|
39
|
-
'force': self.force,
|
|
40
37
|
}
|
|
41
38
|
config.update({
|
|
42
39
|
k: v for (k, v) in self.kwargs.items()
|
|
@@ -48,7 +45,6 @@ class ConformerEmbedder(ConformerProcessor):
|
|
|
48
45
|
mol,
|
|
49
46
|
method=self.method,
|
|
50
47
|
num_conformers=self.num_conformers,
|
|
51
|
-
force=self.force,
|
|
52
48
|
**self.kwargs,
|
|
53
49
|
)
|
|
54
50
|
|
|
@@ -110,9 +110,10 @@ class Feature(abc.ABC):
|
|
|
110
110
|
'type `float`, `int`, `bool` or `None`.'
|
|
111
111
|
)
|
|
112
112
|
if not math.isfinite(value):
|
|
113
|
-
warn(
|
|
113
|
+
warnings.warn(
|
|
114
114
|
f'Found value of {self.name} to be non-finite. '
|
|
115
|
-
f'Value received: {value}. Converting it to a value of 0.'
|
|
115
|
+
f'Value received: {value}. Converting it to a value of 0.',
|
|
116
|
+
stacklevel=2
|
|
116
117
|
)
|
|
117
118
|
value = 0.0
|
|
118
119
|
return np.asarray([value], dtype=self.dtype)
|
|
@@ -380,10 +381,3 @@ default_vocabulary = {
|
|
|
380
381
|
],
|
|
381
382
|
}
|
|
382
383
|
|
|
383
|
-
|
|
384
|
-
def warn(message: str) -> None:
|
|
385
|
-
warnings.warn(
|
|
386
|
-
message=message,
|
|
387
|
-
category=UserWarning,
|
|
388
|
-
stacklevel=1
|
|
389
|
-
)
|
|
@@ -180,6 +180,12 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
180
180
|
bond_features = [
|
|
181
181
|
features.BondType(vocab)
|
|
182
182
|
]
|
|
183
|
+
if not default_bond_features and self.radius > 1:
|
|
184
|
+
warnings.warn(
|
|
185
|
+
'Replacing user-specified bond features with default bond features, '
|
|
186
|
+
'as `radius`>1. When `radius`>1, only bond types are considered.',
|
|
187
|
+
stacklevel=2
|
|
188
|
+
)
|
|
183
189
|
default_molecule_features = (
|
|
184
190
|
molecule_features == 'auto' or molecule_features == 'default'
|
|
185
191
|
)
|
|
@@ -213,9 +219,10 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
213
219
|
mol = chem.Mol.from_encoding(x, explicit_hs=self.include_hs)
|
|
214
220
|
|
|
215
221
|
if mol is None:
|
|
216
|
-
warn(
|
|
222
|
+
warnings.warn(
|
|
217
223
|
f'Could not obtain `chem.Mol` from {x}. '
|
|
218
|
-
'Returning `None` (proceeding without it).'
|
|
224
|
+
'Returning `None` (proceeding without it).',
|
|
225
|
+
stacklevel=2
|
|
219
226
|
)
|
|
220
227
|
return None
|
|
221
228
|
|
|
@@ -245,10 +252,11 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
245
252
|
|
|
246
253
|
if molecule_feature is not None:
|
|
247
254
|
if 'feature' in context:
|
|
248
|
-
warn(
|
|
255
|
+
warnings.warn(
|
|
249
256
|
'Found both inputted and computed context feature. '
|
|
250
257
|
'Overwriting inputted context feature with computed '
|
|
251
|
-
'context feature (based on `molecule_features`).'
|
|
258
|
+
'context feature (based on `molecule_features`).',
|
|
259
|
+
stacklevel=2
|
|
252
260
|
)
|
|
253
261
|
context['feature'] = molecule_feature
|
|
254
262
|
|
|
@@ -284,9 +292,6 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
284
292
|
edge['target'] = np.asarray(
|
|
285
293
|
[path[-1] for path in paths], dtype=self.index_dtype
|
|
286
294
|
)
|
|
287
|
-
edge['length'] = np.asarray(
|
|
288
|
-
[len(path) - 1 for path in paths], dtype=self.index_dtype
|
|
289
|
-
)
|
|
290
295
|
if bond_feature is not None:
|
|
291
296
|
zero_bond_feature = np.array(
|
|
292
297
|
[[1., 0., 0., 0., 0.]], dtype=bond_feature.dtype
|
|
@@ -297,7 +302,6 @@ class MolGraphFeaturizer(Featurizer):
|
|
|
297
302
|
edge['feature'] = self._expand_bond_features(
|
|
298
303
|
mol, paths, bond_feature,
|
|
299
304
|
)
|
|
300
|
-
edge['length'] = np.eye(self.radius + 1, dtype=self.feature_dtype)[edge['length']]
|
|
301
305
|
|
|
302
306
|
if self.super_atom:
|
|
303
307
|
node, edge = self._add_super_atom(node, edge)
|
|
@@ -533,9 +537,10 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
533
537
|
mol = chem.Mol.from_encoding(x, explicit_hs=explicit_hs)
|
|
534
538
|
|
|
535
539
|
if mol is None:
|
|
536
|
-
warn(
|
|
540
|
+
warnings.warn(
|
|
537
541
|
f'Could not obtain `chem.Mol` from {x}. '
|
|
538
|
-
'Proceeding without it.'
|
|
542
|
+
'Proceeding without it.',
|
|
543
|
+
stacklevel=2
|
|
539
544
|
)
|
|
540
545
|
return None
|
|
541
546
|
|
|
@@ -575,10 +580,11 @@ class MolGraphFeaturizer3D(MolGraphFeaturizer):
|
|
|
575
580
|
|
|
576
581
|
if molecule_feature is not None:
|
|
577
582
|
if 'feature' in context:
|
|
578
|
-
warn(
|
|
583
|
+
warnings.warn(
|
|
579
584
|
'Found both inputted and computed context feature. '
|
|
580
585
|
'Overwriting inputted context feature with computed '
|
|
581
|
-
'context feature (based on `molecule_features`).'
|
|
586
|
+
'context feature (based on `molecule_features`).',
|
|
587
|
+
stacklevel=2
|
|
582
588
|
)
|
|
583
589
|
context['feature'] = molecule_feature
|
|
584
590
|
|
|
@@ -740,23 +746,9 @@ def _add_super_edges(
|
|
|
740
746
|
edge['self_loop'], [(0, num_nodes * num_super_nodes * 2)],
|
|
741
747
|
constant_values=False,
|
|
742
748
|
)
|
|
743
|
-
if 'length' in edge:
|
|
744
|
-
edge['length'] = np.pad(edge['length'], [(0, 0), (1, 0)])
|
|
745
|
-
zero_array = np.zeros([num_nodes * num_super_nodes * 2], dtype='int32')
|
|
746
|
-
edge_length_dim = edge['length'].shape[1]
|
|
747
|
-
virtual_edge_length = np.eye(edge_length_dim)[zero_array]
|
|
748
|
-
edge['length'] = np.concatenate([edge['length'], virtual_edge_length])
|
|
749
|
-
edge['length'] = edge['length'].astype(feature_dtype)
|
|
750
749
|
|
|
751
750
|
return edge
|
|
752
751
|
|
|
753
|
-
|
|
754
|
-
def warn(message: str) -> None:
|
|
755
|
-
warnings.warn(
|
|
756
|
-
message=message,
|
|
757
|
-
category=UserWarning,
|
|
758
|
-
stacklevel=1
|
|
759
|
-
)
|
|
760
752
|
|
|
761
753
|
MolFeaturizer = MolGraphFeaturizer
|
|
762
754
|
MolFeaturizer3D = MolGraphFeaturizer3D
|