molcraft 0.1.0a22__py3-none-any.whl → 0.1.0a23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of molcraft might be problematic.

molcraft/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = '0.1.0a22'
+ __version__ = '0.1.0a23'
 
  import os
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -14,4 +14,4 @@ from molcraft import records
  from molcraft import tensors
  from molcraft import callbacks
  from molcraft import datasets
- from molcraft import losses
+ from molcraft import losses
molcraft/callbacks.py CHANGED
@@ -1,5 +1,5 @@
- import keras
  import warnings
+ import keras
  import numpy as np
 
 
molcraft/chem.py CHANGED
@@ -3,6 +3,7 @@ import collections
  import numpy as np
 
  from rdkit import Chem
+ from rdkit.Chem import AllChem
  from rdkit.Chem import Lipinski
  from rdkit.Chem import rdDistGeom
  from rdkit.Chem import rdDepictor
@@ -31,10 +32,8 @@ class Mol(Chem.Mol):
 
  @property
  def encoding(self):
- if hasattr(self, '_encoding'):
- return self._encoding
- return None
-
+ return getattr(self, '_encoding', None)
+
  @property
  def bonds(self) -> list['Bond']:
  if not hasattr(self, '_bonds'):
@@ -67,7 +66,7 @@ class Mol(Chem.Mol):
  atom = atom.GetIdx()
  return Atom.cast(self.GetAtomWithIdx(int(atom)))
 
- def get_path_between_atoms(
+ def get_shortest_path_between_atoms(
  self,
  atom_i: int | Chem.Atom,
  atom_j: int | Chem.Atom
@@ -107,13 +106,13 @@ class Mol(Chem.Mol):
 
  def get_conformer(self, index: int = 0) -> 'Conformer':
  if self.num_conformers == 0:
- warnings.warn('Molecule has no conformer.')
+ warnings.warn(f'{self} has no conformer. Returning None.')
  return None
  return Conformer.cast(self.GetConformer(index))
 
  def get_conformers(self) -> list['Conformer']:
  if self.num_conformers == 0:
- warnings.warn('Molecule has no conformer.')
+ warnings.warn(f'{self} has no conformers. Returning an empty list.')
  return []
  return [Conformer.cast(x) for x in self.GetConformers()]
 
@@ -124,7 +123,8 @@ class Mol(Chem.Mol):
  return None
 
  def __repr__(self) -> str:
- return f'<{self.__class__.__name__} {self.canonical_smiles} at {hex(id(self))}>'
+ encoding = self.encoding or self.canonical_smiles
+ return f'<{self.__class__.__name__} {encoding} at {hex(id(self))}>'
 
 
  class Conformer(Chem.Conformer):
@@ -251,7 +251,10 @@ def sanitize_mol(
  flag = Chem.SanitizeMol(mol, catchErrors=True)
  if flag != Chem.SanitizeFlags.SANITIZE_NONE:
  if strict:
- return None
+ raise ValueError(f'Could not sanitize {mol}.')
+ warnings.warn(
+ f'Could not sanitize {mol}. Proceeding with partial sanitization.'
+ )
  # Sanitize mol, excluding the steps causing the error previously
  Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL^flag)
  if assign_stereo_chemistry:
@@ -411,13 +414,12 @@ def embed_conformers(
  'KDG': rdDistGeom.KDG()
  }
  mol = Mol(mol)
- encoding = mol.encoding or mol.canonical_smiles
  embedding_method = available_embedding_methods.get(method)
  if embedding_method is None:
- raise ValueError(
- f'Could not find `method` {method!r}. Specify either of: '
- '`ETDG`, `ETKDG`, `ETKDGv2`, `ETKDGv3`, `srETKDGv3` or `KDG`.'
+ warnings.warn(
+ f'{method} is not available. Proceeding with ETKDGv3.'
  )
+ embedding_method = available_embedding_methods['ETKDGv3']
 
  for key, value in kwargs.items():
  setattr(embedding_method, key, value)
@@ -438,8 +440,8 @@ def embed_conformers(
  if num_successes < num_conformers:
  warnings.warn(
  f'Could only embed {num_successes} out of {num_conformers} conformer(s) for '
- f'{encoding!r} using the specified method ({method!r}) and parameters. Attempting '
- f'to embed the remaining {num_conformers-num_successes} using fallback methods.',
+ f'{mol} using the specified method ({method}) and parameters. Attempting to '
+ f'embed the remaining {num_conformers-num_successes} using fallback methods.',
  )
  max_iters = 20 * mol.num_atoms # Doubling the number of iterations
  for fallback_method in [method, 'ETDG', 'KDG']:
@@ -457,10 +459,13 @@ def embed_conformers(
  break
  else:
  raise RuntimeError(
- f'Could not embed {num_conformers} conformer(s) for {encoding!r}. '
+ f'Could not embed {num_conformers} conformer(s) for {mol}. '
  )
  return mol
 
+
+ import warnings
+
  def optimize_conformers(
  mol: Mol,
  method: str = 'UFF',
@@ -469,14 +474,17 @@ def optimize_conformers(
  ignore_interfragment_interactions: bool = True,
  vdw_threshold: float = 10.0,
  ) -> Mol:
- available_force_field_methods = [
- 'MMFF', 'MMFF94', 'MMFF94s', 'UFF'
- ]
+ if mol.num_conformers == 0:
+ warnings.warn(
+ f'{mol} has no conformers to optimize. Proceeding without it.'
+ )
+ return Mol(mol)
+ available_force_field_methods = ['MMFF', 'MMFF94', 'MMFF94s', 'UFF']
  if method not in available_force_field_methods:
- raise ValueError(
- f'Could not find `method` {method!r}. Specify either of: '
- '`UFF`, `MMFF`, `MMFF94` or `MMFF94s`.'
+ warnings.warn(
+ f'{method} is not available. Proceeding with universal force field (UFF).'
  )
+ method = 'UFF'
  mol_optimized = Mol(mol)
  try:
  if method.startswith('MMFF'):
@@ -500,7 +508,7 @@ def optimize_conformers(
  )
  except RuntimeError as e:
  warnings.warn(
- f'{method} force field minimization did not succeed. Proceeding without it.',
+ f'Unsuccessful {method} force field minimization for {mol}. Proceeding without it.',
  )
  return Mol(mol)
  return mol_optimized
@@ -513,10 +521,9 @@ def prune_conformers(
  ) -> Mol:
  if mol.num_conformers == 0:
  warnings.warn(
- 'Molecule has no conformers. To embed conformers, invoke the `embed` method, '
- 'and optionally followed by `minimize()` to perform force field minimization.',
+ f'{mol} has no conformers to prune. Proceeding without it.'
  )
- return mol
+ return Chem.Mol(mol)
 
  threshold = threshold or 0.0
  deviations = conformer_deviations(mol)
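
The conformer utilities above now degrade gracefully: an unknown embedding or force-field method warns and falls back (to ETKDGv3 and UFF respectively) instead of raising, conformer-free molecules are returned early with a warning, and sanitize_mol(..., strict=True) now raises a ValueError rather than returning None. A minimal sketch of the new flow, assuming the call signatures shown in the commented example in diffusion.py (Mol.from_encoding and the two-argument embed_conformers), which are not otherwise documented here:

from molcraft import chem

mol = chem.Mol.from_encoding('CCO')
mol = chem.embed_conformers(mol, 1)    # an unrecognized `method` now warns and falls back to ETKDGv3
mol = chem.optimize_conformers(mol)    # defaults to UFF; an unrecognized `method` warns and falls back to UFF
energies = chem.conformer_energies(mol, method="UFF")   # as used by ForceFieldEnergy in descriptors.py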
molcraft/datasets.py CHANGED
@@ -1,3 +1,4 @@
+ import warnings
  import numpy as np
  import pandas as pd
  import typing
molcraft/descriptors.py CHANGED
@@ -1,5 +1,7 @@
+ import warnings
  import keras
  import numpy as np
+
  from rdkit.Chem import rdMolDescriptors
 
  from molcraft import chem
@@ -12,9 +14,7 @@ class Descriptor(features.Feature):
  def __call__(self, mol: chem.Mol) -> np.ndarray:
  if not isinstance(mol, chem.Mol):
  raise ValueError(
- f'Input to {self.name} needs to be a `chem.Mol`, which '
- 'implements two properties that should be iterated over '
- 'to compute features: `atoms` and `bonds`.'
+ f'Input to {self.name} must be a `chem.Mol` object.'
  )
  descriptor = self.call(mol)
  func = (
@@ -30,6 +30,23 @@ class Descriptor(features.Feature):
  return np.concatenate(descriptors)
 
 
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class Descriptor3D(Descriptor):
+
+ def __call__(self, mol: chem.Mol) -> np.ndarray:
+ if not isinstance(mol, chem.Mol):
+ raise ValueError(
+ f'Input to {self.name} must be a `chem.Mol` object.'
+ )
+ if mol.num_conformers == 0:
+ raise ValueError(
+ f'The inputted `chem.Mol` to {self.name} must embed a conformer. '
+ f'It is recommended that {self.name} is used as a molecule feature '
+ 'for `MolGraphFeaturizer3D`, which by default embeds a conformer.'
+ )
+ return super().__call__(mol)
+
+
  @keras.saving.register_keras_serializable(package='molcraft')
  class MolWeight(Descriptor):
  def call(self, mol: chem.Mol) -> np.ndarray:
@@ -77,7 +94,7 @@ class NumHydrogenDonors(Descriptor):
  @keras.saving.register_keras_serializable(package='molcraft')
  class NumHydrogenAcceptors(Descriptor):
  def call(self, mol: chem.Mol) -> np.ndarray:
- return rdMolDescriptors.CalcNumHBA(mol)
+ return rdMolDescriptors.CalcNumHBA(mol)
 
 
  @keras.saving.register_keras_serializable(package='molcraft')
@@ -89,7 +106,7 @@ class NumRotatableBonds(Descriptor):
  @keras.saving.register_keras_serializable(package='molcraft')
  class NumRings(Descriptor):
  def call(self, mol: chem.Mol) -> np.ndarray:
- return rdMolDescriptors.CalcNumRings(mol)
+ return rdMolDescriptors.CalcNumRings(mol)
 
 
  @keras.saving.register_keras_serializable(package='molcraft')
@@ -105,3 +122,18 @@ class AtomCount(Descriptor):
  if atom.GetSymbol() == self.atom_type:
  count += 1
  return count
+
+ def get_config(self) -> dict:
+ config = super().get_config()
+ config['atom_type'] = self.atom_type
+ return config
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class ForceFieldEnergy(Descriptor3D):
+ """Universal Force Field (UFF) Energy."""
+ def call(self, mol: chem.Mol) -> np.ndarray:
+ mol_copy = chem.Mol(mol)
+ mol_copy = chem.add_hs(mol_copy)
+ return chem.conformer_energies(mol_copy, method="UFF")
+
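
The new Descriptor3D base class rejects molecules without an embedded conformer, which is what ForceFieldEnergy builds on. A minimal sketch of a hypothetical custom 3D descriptor reusing that check; the RadiusOfGyration name and the get_conformer().GetPositions() call are illustrative assumptions, while Descriptor3D, ForceFieldEnergy, chem.Mol.from_encoding and chem.embed_conformers appear in this diff:

import keras
import numpy as np
from molcraft import chem, descriptors

@keras.saving.register_keras_serializable(package='molcraft')
class RadiusOfGyration(descriptors.Descriptor3D):
    # Hypothetical 3D descriptor: Descriptor3D raises a ValueError if no conformer is embedded.
    def call(self, mol: chem.Mol) -> np.ndarray:
        coords = mol.get_conformer().GetPositions()   # conformer coordinates, shape (num_atoms, 3)
        centered = coords - coords.mean(axis=0)
        return np.sqrt((centered ** 2).sum(axis=1).mean())

mol = chem.embed_conformers(chem.Mol.from_encoding('CCO'), 1)
print(RadiusOfGyration()(mol))
print(descriptors.ForceFieldEnergy()(mol))   # UFF conformer energies via chem.conformer_energies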
molcraft/diffusion.py ADDED
@@ -0,0 +1,241 @@
+ import warnings
+ import keras
+ import tensorflow as tf
+ import numpy as np
+
+ from molcraft import ops
+ from molcraft import tensors
+ from molcraft import layers
+ from molcraft import models
+
+
+ # smiles = pd.read_csv('../../data/rt/RIKEN.csv')['smiles'].values
+
+
+ # graph = featurizers.MolGraphFeaturizer3D(super_node=False)(smiles)
+ # graph.node['coordinate']
+
+
+ # encoder = molcraft.models.GraphModel.from_layers(
+ #     [
+ #         diffusion.CoordinateNoise(),
+ #         molcraft.layers.NodeEmbedding(128),
+ #         molcraft.layers.EdgeEmbedding(128),
+ #         molcraft.layers.AddContext('position'),
+ #         molcraft.layers.MPConv(128),
+ #         molcraft.layers.AddContext('position'),
+ #         molcraft.layers.MPConv(128),
+ #         molcraft.layers.AddContext('position'),
+ #     ]
+ # )
+
+ # decoder = keras.Sequential([
+ #     keras.layers.Dense(128, activation='relu'),
+ #     keras.layers.Dense(3),
+ # ])
+
+ # model = diffusion.CoordinateNoisePredictor(encoder, decoder)
+
+ # model(graph)
+
+ # model.save('/tmp/model.keras')
+ # model = molcraft.models.load_model('/tmp/model.keras')
+
+ # model.compile(keras.optimizers.Adam(1e-3), 'mse')
+ # model.fit(graph, epochs=100)
+
+
+ # from rdkit.Geometry import Point3D
+
+ # def energy(smiles, coordinate):
+ #     m = chem.Mol.from_encoding(smiles)
+ #     m = chem.embed_conformers(m, 1)
+ #     conf = m.GetConformer()
+ #     for i in range(m.GetNumAtoms()):
+ #         x, y, z = coordinate[i]
+ #         conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
+ #     return m, chem.conformer_energies(m)[0]
+
+ # def denoise(
+ #     graph: tensors.GraphTensor,
+ #     model,
+ # ):
+
+ #     print("----")
+ #     print(energy(smiles[0], graph[0].node['coordinate'])[-1])
+ #     print("----")
+
+ #     beta = keras.ops.linspace(1e-4, 1e-2, 100)
+ #     alpha = 1 - beta
+ #     alpha_bar = keras.ops.cumprod(alpha)
+ #     sigma = keras.ops.sqrt(beta[1:] * (1.0 - alpha_bar[:-1]) / (1.0 - alpha_bar[1:]))
+
+ #     graph = graph.update(
+ #         {
+ #             'context': {
+ #                 'position': keras.ops.ones_like(graph.context['size']) * 99
+ #             },
+ #             'node': {
+ #                 'coordinate': keras.random.normal(graph.node['coordinate'].shape)
+ #             }
+ #         }
+ #     )
+
+ #     for t in reversed(range(100)):
+ #         alpha_t = alpha[t]
+ #         alpha_bar_t = alpha_bar[t]
+
+ #         a = 1 / keras.ops.sqrt(alpha_t)
+
+ #         b = (1 - alpha_t) / keras.ops.sqrt(1 - alpha_bar_t)
+
+ #         if t > 0:
+ #             z = keras.random.normal(()) * sigma[t-1]
+ #         else:
+ #             z = 0.0
+
+ #         graph = graph.update({
+ #             'node': {
+ #                 'coordinate': (
+ #                     a * (graph.node['coordinate'] - b * model(graph)) + z
+ #                 )
+ #             }
+ #         })
+
+
+ #     print(energy(smiles[0], graph[0].node['coordinate'])[-1])
+
+ #     return graph
+
+ # graph_updated = denoise(graph[:1], model)x
+ # mol, e = energy(smiles[0], graph_updated[0].node['coordinate'])
+ # print(e)
+ # Chem.Mol(mol)
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class CoordinateNoisePredictor(models.GraphModel):
+
+ def __init__(self, encoder, decoder, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.encoder = encoder
+ self.decoder = decoder
+
+ def propagate(self, tensor):
+ return self.decoder(self.encoder(tensor).node['feature'])
+
+ def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
+ with tf.GradientTape() as tape:
+ tensor = self.encoder(tensor)
+ feature = tensor.node['feature']
+ noise_true = tensor.node['label']
+ noise_pred = self.decoder(feature)
+ loss = self.compute_loss(tensor, noise_true, noise_pred)
+ loss = self.optimizer.scale_loss(loss)
+ trainable_weights = self.trainable_weights
+ gradients = tape.gradient(loss, trainable_weights)
+ self.optimizer.apply_gradients(zip(gradients, trainable_weights))
+ return self.compute_metrics(tensor, noise_true, noise_pred)
+
+ def test_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
+ tensor = self.encoder(tensor)
+ feature = tensor.node['feature']
+ noise_true = tensor.node['label']
+ noise_pred = self.decoder(feature)
+ return self.compute_metrics(tensor, noise_true, noise_pred)
+
+ def get_config(self) -> dict:
+ config = super().get_config()
+ config['encoder'] = keras.saving.serialize_keras_object(self.encoder)
+ config['decoder'] = keras.saving.serialize_keras_object(self.decoder)
+ return config
+
+ @classmethod
+ def from_config(cls, config: dict):
+ config['encoder'] = keras.saving.deserialize_keras_object(config['encoder'])
+ config['decoder'] = keras.saving.deserialize_keras_object(config['decoder'])
+ return super().from_config(config)
+
+
+ @keras.saving.register_keras_serializable(package='molcraft')
+ class CoordinateNoise(layers.GraphLayer):
+
+ def __init__(
+ self,
+ beta: tuple[float, float] = (1e-4, 1e-2),
+ position_dim: int = 128,
+ max_timesteps: int = 100,
+ **kwargs
+ ) -> None:
+ super().__init__(**kwargs)
+ self._beta = beta
+ self._max_timesteps = max_timesteps
+ beta = keras.ops.linspace(*self._beta, self._max_timesteps)
+ alpha = 1 - beta
+ alpha_cumprod = keras.ops.cumprod(alpha)
+ alpha_cumprod = keras.ops.expand_dims(alpha_cumprod, -1)
+ self._alpha_cumprod = alpha_cumprod
+ self._timestep_embedding = TimestepEmbedding(dim=position_dim)
+
+ def propagate(self, graph: tensors.GraphTensor) -> tensors.GraphTensor:
+ if 'position' in graph.context:
+ return graph.update({'context': {'position': self._timestep_embedding(graph.context['position'])}})
+
+ timestep = keras.random.randint(
+ shape=(graph.num_subgraphs,), minval=0, maxval=self._max_timesteps
+ )
+ alpha_cumprod = ops.gather(
+ ops.gather(self._alpha_cumprod, timestep), graph.graph_indicator
+ )
+ epsilon = keras.random.normal(
+ shape=keras.ops.shape(graph.node['coordinate']), mean=0, stddev=1
+ )
+ noisy_coordinate = (
+ keras.ops.sqrt(alpha_cumprod) * graph.node['coordinate'] +
+ keras.ops.sqrt(1 - alpha_cumprod) * epsilon
+ )
+ timestep = self._timestep_embedding(timestep)
+ return graph.update(
+ {
+ 'context': {
+ 'position': timestep,
+ },
+ 'node': {
+ 'coordinate': noisy_coordinate,
+ 'label': epsilon
+ },
+ }
+ )
+
+ def get_config(self) -> dict:
+ config = super().get_config()
+ config['beta'] = self._beta
+ config['max_timesteps'] = self._max_timesteps
+ return config
+
+
+ class TimestepEmbedding(keras.layers.Layer):
+
+ def __init__(self, dim: int, max_wavelength: int = 10000, **kwargs) -> None:
+ super().__init__(**kwargs)
+ self._dim = dim
+ self._max_wavelength = max_wavelength
+
+ def call(self, inputs: tf.Tensor) -> tf.Tensor:
+ timestep = keras.ops.cast(inputs, 'float32')
+ embedding = keras.ops.log(self._max_wavelength) / (self._dim // 2 - 1)
+ embedding = keras.ops.exp(
+ -embedding * keras.ops.arange(self._dim // 2, dtype='float32')
+ )
+ embedding = timestep[:, None] * embedding[None, :]
+ embedding = keras.ops.concatenate(
+ [keras.ops.sin(embedding), keras.ops.cos(embedding)], axis=-1
+ )
+ return embedding
+
+
+ def get_config(self) -> dict:
+ config = super().get_config()
+ config['dim'] = self._dim
+ config['max_wavelength'] = self._max_wavelength
+ return config
+
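
For orientation, the commented-out scratch code at the top of the new module outlines the intended training setup. Tidied into a sketch (hedged: the featurizer and layer names MolGraphFeaturizer3D, NodeEmbedding, EdgeEmbedding, AddContext and MPConv are taken verbatim from that comment block rather than from documented API), it looks roughly like this. Note that AddContext in layers.py now broadcasts the context to every node when the graph has no super node, which is what makes super_node=False viable here:

import keras
from molcraft import diffusion, featurizers, layers, models

# Featurize SMILES into 3D molecular graphs; coordinates live in graph.node['coordinate'].
graph = featurizers.MolGraphFeaturizer3D(super_node=False)(['CCO', 'c1ccccc1O'])

# Encoder: CoordinateNoise noises the coordinates and writes a timestep embedding
# to the 'position' context, which AddContext mixes back into the node features.
encoder = models.GraphModel.from_layers([
    diffusion.CoordinateNoise(),
    layers.NodeEmbedding(128),
    layers.EdgeEmbedding(128),
    layers.AddContext('position'),
    layers.MPConv(128),
    layers.AddContext('position'),
])

# Decoder: maps node features to the predicted per-atom noise (x, y, z).
decoder = keras.Sequential([
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(3),
])

model = diffusion.CoordinateNoisePredictor(encoder, decoder)
model.compile(keras.optimizers.Adam(1e-3), 'mse')
model.fit(graph, epochs=10)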
molcraft/features.py CHANGED
@@ -1,7 +1,7 @@
+ import warnings
  import abc
  import math
  import keras
- import warnings
  import numpy as np
 
  from molcraft import chem
@@ -41,14 +41,14 @@ class Feature(abc.ABC):
 
  def __call__(self, mol: chem.Mol) -> np.ndarray:
  if not isinstance(mol, chem.Mol):
- raise TypeError(f'Input to {self.name} must be a `chem.Mol` instance.')
+ raise TypeError(f'Input to {self.name} must be a `chem.Mol` object.')
  features = self.call(mol)
  if len(features) != mol.num_atoms and len(features) != mol.num_bonds:
  raise ValueError(
  f'The number of features computed by {self.name} does not '
  'match the number of atoms or bonds of the `chem.Mol` object. '
- 'Make sure to iterate over `atoms` or `bonds` of `chem.Mol` '
- 'when computing features.'
+ 'Make sure to iterate over `atoms` or `bonds` of the `chem.Mol` '
+ 'object when computing features.'
  )
  if len(features) == 0:
  # Edge case: no atoms or bonds in the molecule.
@@ -109,7 +109,6 @@ class Feature(abc.ABC):
  warnings.warn(
  f'Found value of {self.name} to be non-finite. '
  f'Value received: {value}. Converting it to a value of 0.',
- stacklevel=2
  )
  value = 0.0
  return np.asarray([value], dtype=self.dtype)
molcraft/featurizers.py CHANGED
@@ -1,8 +1,8 @@
+ import warnings
  import keras
  import json
  import abc
  import typing
- import os
  import numpy as np
  import pandas as pd
  import tensorflow as tf
molcraft/layers.py CHANGED
@@ -1,6 +1,6 @@
+ import warnings
  import keras
  import tensorflow as tf
- import warnings
  import functools
  from keras.src.models import functional
 
@@ -350,11 +350,8 @@ class GraphConv(GraphLayer):
  )
  if self._project_residual:
  warnings.warn(
- '`skip_connect` is set to `True`, but found incompatible dim '
- 'between input (node feature dim) and output (`self.units`). '
- 'Automatically applying a projection layer to residual to '
- 'match input and output. ',
- stacklevel=2,
+ 'Found incompatible dim between input and output. Applying '
+ 'a projection layer to residual to match input and output dim.',
  )
  self._residual_dense = self.get_dense(
  self.units, name='residual_dense'
@@ -613,10 +610,8 @@ class GIConv(GraphConv):
  if not self._update_edge_feature:
  if (edge_feature_dim != node_feature_dim):
  warnings.warn(
- 'Found edge feature dim to be incompatible with node feature dim. '
- 'Automatically adding a edge feature projection layer to match '
- 'the dim of node features.',
- stacklevel=2,
+ 'Found edge and node feature dim to be incompatible. Applying a '
+ 'projection layer to edge features to match the dim of the node features.',
  )
  self._update_edge_feature = True
 
@@ -870,10 +865,10 @@ class MPConv(GraphConv):
  self._project_previous_node_feature = node_feature_dim != self.units
  if self._project_previous_node_feature:
  warnings.warn(
- 'Input node feature dim does not match updated node feature dim. '
- 'To make sure input node feature can be passed as `states` to the '
- 'GRU cell, it will automatically be projected prior to it.',
- stacklevel=2
+ 'Inputted node feature dim does not match updated node feature dim, '
+ 'which is required for the GRU update. Applying a projection layer to '
+ 'the inputted node features prior to the GRU update, to match dim '
+ 'of the updated node feature dim.'
  )
  self._previous_node_dense = self.get_dense(self.units)
 
@@ -1497,6 +1492,7 @@ class AddContext(GraphLayer):
 
  def build(self, spec: tensors.GraphTensor.Spec) -> None:
  feature_dim = spec.node['feature'].shape[-1]
+ self._has_super_node = 'super' in spec.node
  if self._intermediate_dim is None:
  self._intermediate_dim = feature_dim * 2
  self._intermediate_dense = self.get_dense(
@@ -1515,9 +1511,14 @@
  context = self._intermediate_dense(context)
  context = self._intermediate_norm(context)
  context = self._final_dense(context)
- node_feature = ops.scatter_add(
- tensor.node['feature'], tensor.node['super'], context
- )
+ if self._has_super_node:
+ node_feature = ops.scatter_add(
+ tensor.node['feature'], tensor.node['super'], context
+ )
+ else:
+ node_feature = (
+ tensor.node['feature'] + ops.gather(context, tensor.graph_indicator)
+ )
  data = {'node': {'feature': node_feature}}
  if self._drop:
  data['context'] = {self._field: None}
@@ -1561,8 +1562,7 @@ class GraphNetwork(GraphLayer):
  if self._update_node_feature:
  warnings.warn(
  'Node feature dim does not match `units` of the first layer. '
- 'Automatically adding a node projection layer to match `units`.',
- stacklevel=2
+ 'Applying a projection layer to node features to match `units`.',
  )
  self._node_dense = self.get_dense(units)
  self._has_edge_feature = 'feature' in spec.edge
@@ -1572,8 +1572,7 @@ class GraphNetwork(GraphLayer):
  if self._update_edge_feature:
  warnings.warn(
  'Edge feature dim does not match `units` of the first layer. '
- 'Automatically adding a edge projection layer to match `units`.',
- stacklevel=2
+ 'Applying projection layer to edge features to match `units`.'
  )
  self._edge_dense = self.get_dense(units)
 
molcraft/losses.py CHANGED
@@ -1,3 +1,4 @@
+ import warnings
  import keras
  import numpy as np
 
molcraft/models.py CHANGED
@@ -1,3 +1,4 @@
+ import warnings
  import typing
  import keras
  import numpy as np
@@ -111,7 +112,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
  def __new__(cls, *args, **kwargs):
  if _functional_init_arguments(args, kwargs) and cls == GraphModel:
  return FunctionalGraphModel(*args, **kwargs)
- return typing.cast(GraphModel, super().__new__(cls))
+ return super().__new__(cls)
 
  def __init__(self, *args, **kwargs):
  self._model_layers = kwargs.pop('model_layers', None)
@@ -137,6 +138,8 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
  """
  if not tensors.is_graph(graph_layers[0]):
  return cls(model_layers=graph_layers)
+ elif cls != GraphModel:
+ return cls(model_layers=graph_layers[1:])
  inputs: dict = graph_layers.pop(0)
  x = inputs
  for layer in graph_layers:
molcraft/ops.py CHANGED
@@ -1,3 +1,4 @@
+ import warnings
  import keras
  import numpy as np
  import tensorflow as tf
molcraft/records.py CHANGED
@@ -1,9 +1,9 @@
+ import warnings
  import os
  import math
  import glob
  import time
  import typing
- import warnings
  import tensorflow as tf
  import numpy as np
  import pandas as pd
@@ -164,8 +164,8 @@ def _write_tfrecord(
  _write_example(tensor)
  except Exception as e:
  warnings.warn(
- f"Could not write record for index {i + start_index}, proceeding without it."
- f"Exception raised:\n{e}"
+ f'Could not write record for index {i + start_index}, proceeding without it.'
+ f'Exception raised:\n{e}'
  )
 
  def _serialize_example(
molcraft/tensors.py CHANGED
@@ -1,3 +1,4 @@
+ import warnings
  import tensorflow as tf
  import keras
  import typing
molcraft-0.1.0a23.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: molcraft
- Version: 0.1.0a22
+ Version: 0.1.0a23
  Summary: Graph Neural Networks for Molecular Machine Learning
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
  License: MIT License
@@ -35,7 +35,6 @@ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: tensorflow>=2.16
- Requires-Dist: tensorflow-text>=2.16
  Requires-Dist: rdkit>=2023.9.5
  Requires-Dist: pandas>=1.0.3
  Requires-Dist: ipython>=8.12.0
@@ -43,7 +42,7 @@ Provides-Extra: gpu
  Requires-Dist: tensorflow[and-cuda]>=2.16; extra == "gpu"
  Dynamic: license-file
 
- <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo", width="90%">
+ <img src="https://github.com/akensert/molcraft/blob/main/docs/_static/molcraft-logo.png" alt="molcraft-logo" width="90%">
 
  **Deep Learning on Molecules**: A Minimalistic GNN package for Molecular ML.
 
molcraft-0.1.0a23.dist-info/RECORD ADDED
@@ -0,0 +1,22 @@
+ molcraft/__init__.py,sha256=QAJQSS_jOBzLxGRL7ciskMY_kn2ARLCg7FVTWeF-D_I,432
+ molcraft/callbacks.py,sha256=B4gGWjVW_1ORrt38jfk1ZFI9c0rOpN5sgjGWVqs3Ess,3571
+ molcraft/chem.py,sha256=0Zni91J4fQJW16R6g3jlOX9Vm8FH0Z5NOx0s7_X-xQw,22232
+ molcraft/datasets.py,sha256=1rHccqra5chIBwo2pz9vduyv0i07uY3CABzmAqWiFBU,4161
+ molcraft/descriptors.py,sha256=uqMPeIKqfkHC04FgztxS1FsfC3zsFJhvniZO70D22l0,4553
+ molcraft/diffusion.py,sha256=HR1kp2MuCWyUtGoGXvEA6kXTdWMGD2w5EZEoKLI1ilM,7902
+ molcraft/features.py,sha256=q-wuRP9YjPu_v5czipsh00VEXEjgFaeuLk6dbgyD_VM,13505
+ molcraft/featurizers.py,sha256=nGdV9G-aO43-vgKPNFfEOESW2hVvIvixHu3EHjIRrgU,18097
+ molcraft/layers.py,sha256=ba0WdQC2IUNsLy9pV0mIm5BBO3MCvR2lYWWuq1-8M4M,64522
+ molcraft/losses.py,sha256=piu4XYAgjnK7k9LqA4Vkh-SooYZ31sWwRfG1cacCwyA,1081
+ molcraft/models.py,sha256=-at-yFWj8mIkGchVY39m9-HtTnKqAUDQrF6wDrQXNuQ,22040
+ molcraft/ops.py,sha256=Qf9l1oOg20HNi9L9nLgf_c_5v09GtXDCc-1fkEQqn54,6194
+ molcraft/records.py,sha256=dAaq5tr3B1jXNgd3tkvVxWgUW2Qa9gtXWE3TToeiKmQ,6283
+ molcraft/tensors.py,sha256=JILwU9l6kUsrtJJ9YSzmT90_G5kbZZC6LAShvsnZOrk,22493
+ molcraft/applications/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ molcraft/applications/chromatography.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ molcraft/applications/proteomics.py,sha256=BL3EtW-q-0j79pLYO7npC67mA2ApRhH-XI4rOaP8_wc,8407
+ molcraft-0.1.0a23.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
+ molcraft-0.1.0a23.dist-info/METADATA,sha256=RomzSM8GmqbILgyZMmzR4GjrBLWDPHGZcedb1VDmqxs,3892
+ molcraft-0.1.0a23.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ molcraft-0.1.0a23.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
+ molcraft-0.1.0a23.dist-info/RECORD,,
molcraft-0.1.0a22.dist-info/RECORD DELETED
@@ -1,21 +0,0 @@
- molcraft/__init__.py,sha256=O88EmicQAD8oz9oFMXk_IzFChQEbbU-BCs3IE-c9Dkk,431
- molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
- molcraft/chem.py,sha256=ynrEpWZL2D370p7CqH2kE1KhBByq7IiuQbUNoKQt96I,22028
- molcraft/datasets.py,sha256=Nd2lw5USUZE52vvAiNr-q-n03Y3--NlZlK0NzqHgp-E,4145
- molcraft/descriptors.py,sha256=Cl3KnBPsTST7XLgRLktkX5LwY9MV0P_lUlrt8iPV5no,3508
- molcraft/features.py,sha256=s0WeV8eZcDEypPgC1m37f4s9QkvWIlVgn-L43Cdsa14,13525
- molcraft/featurizers.py,sha256=1yBz5-JA7IhNm0dGivvVm1nJ5QGck8VQXtwHPWFbTuQ,18091
- molcraft/layers.py,sha256=H7XZru4XGJA6gbRO9V1BsGqh1mIrMdhzNCKS5o6oNok,64544
- molcraft/losses.py,sha256=qnS2yC5g-O3n_zVea9MR6TNiFraW2yqRgePOisoUP4A,1065
- molcraft/models.py,sha256=2Pc1htT9fCukGd8ZxrvE0rzEHsPBm0pluHw4FZXaUE4,21963
- molcraft/ops.py,sha256=bQbdFDt9waxVCzF5-dkTB6vlpj9eoSt8I4Qg7ZGXbsU,6178
- molcraft/records.py,sha256=sopYElKWC3A9QE5I8_957v3faLb2Wt5WILHZv_FLLds,6283
- molcraft/tensors.py,sha256=vk-W8zZu-re1g18YevDEEoVQRxT4AdIiMdI-4EvtJI4,22477
- molcraft/applications/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- molcraft/applications/chromatography.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- molcraft/applications/proteomics.py,sha256=BL3EtW-q-0j79pLYO7npC67mA2ApRhH-XI4rOaP8_wc,8407
- molcraft-0.1.0a22.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
- molcraft-0.1.0a22.dist-info/METADATA,sha256=1OHx3-Q94fFEi21l0p3bnMjU-Q0EHaZLm4PU1A6QbkU,3930
- molcraft-0.1.0a22.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- molcraft-0.1.0a22.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
- molcraft-0.1.0a22.dist-info/RECORD,,