molcraft 0.1.0a12__py3-none-any.whl → 0.1.0a14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a12'
1
+ __version__ = '0.1.0a14'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
molcraft/chem.py CHANGED
@@ -331,20 +331,20 @@ def get_shortest_paths(
331
331
  def get_periodic_table():
332
332
  return Chem.GetPeriodicTable()
333
333
 
334
- def gasteiger_charges(mol: 'Mol') -> list[float]:
334
+ def partial_charges(mol: 'Mol') -> list[float]:
335
335
  rdPartialCharges.ComputeGasteigerCharges(mol)
336
336
  return [atom.GetDoubleProp("_GasteigerCharge") for atom in mol.atoms]
337
337
 
338
338
  def logp_contributions(mol: 'Mol') -> list[float]:
339
339
  return [i[0] for i in rdMolDescriptors._CalcCrippenContribs(mol)]
340
340
 
341
- def molar_refractivity_contribution(mol: 'Mol') -> list[float]:
341
+ def molar_refractivity_contributions(mol: 'Mol') -> list[float]:
342
342
  return [i[1] for i in rdMolDescriptors._CalcCrippenContribs(mol)]
343
343
 
344
- def tpsa_contribution(mol: 'Mol') -> list[float]:
344
+ def total_polar_surface_area_contributions(mol: 'Mol') -> list[float]:
345
345
  return list(rdMolDescriptors._CalcTPSAContribs(mol))
346
346
 
347
- def asa_contribution(mol: 'Mol') -> list[float]:
347
+ def accessible_surface_area_contributions(mol: 'Mol') -> list[float]:
348
348
  return list(rdMolDescriptors._CalcLabuteASAContribs(mol)[0])
349
349
 
350
350
  def hydrogen_acceptors(mol: 'Mol') -> list[bool]:
molcraft/datasets.py CHANGED
@@ -1,123 +1,131 @@
1
1
  import numpy as np
2
2
  import pandas as pd
3
+ import typing
3
4
 
4
5
 
5
6
  def split(
6
7
  data: pd.DataFrame | np.ndarray,
8
+ *,
7
9
  train_size: float | None = None,
8
10
  validation_size: float | None = None,
9
- test_size: float = 0.1,
11
+ test_size: float | None = None,
12
+ groups: str | np.ndarray = None,
10
13
  shuffle: bool = False,
11
14
  random_state: int | None = None,
12
- ) -> pd.DataFrame | np.ndarray:
13
- """Splits dataset into subsets.
15
+ ) -> tuple[np.ndarray | pd.DataFrame, ...]:
16
+ """Splits the dataset into subsets.
14
17
 
15
18
  Args:
16
19
  data:
17
20
  A pd.DataFrame or np.ndarray object.
18
21
  train_size:
19
- Optional train size, as a fraction (`float`) or size (`int`).
22
+ The size of the train set.
20
23
  validation_size:
21
- Optional validation size, as a fraction (`float`) or size (`int`).
24
+ The size of the validation set.
22
25
  test_size:
23
- Required test size, as a fraction (`float`) or size (`int`).
26
+ The size of the test set.
27
+ groups:
28
+ The groups to perform the splitting on.
24
29
  shuffle:
25
30
  Whether the dataset should be shuffled prior to splitting.
26
31
  random_state:
27
- The random state (or seed). Only applicable if shuffling.
32
+ The random state/seed. Only applicable if shuffling.
28
33
  """
34
+ if not isinstance(data, (pd.DataFrame, np.ndarray)):
35
+ raise ValueError(f'Unsupported `data` type ({type(data)}).')
29
36
 
30
- if not isinstance(data, (pd.DataFrame, np.ndarray, list)):
31
- raise ValueError(
32
- '`data` needs to be a pd.DataFrame, np.ndarray or a list. '
33
- f'Found {type(data)}.'
34
- )
35
-
36
- size = len(data)
37
+ if isinstance(groups, str):
38
+ groups = data[groups].values
39
+ elif groups is None:
40
+ groups = np.arange(len(data))
37
41
 
38
- if test_size is None:
39
- raise ValueError('`test_size` is required.')
40
- elif test_size <= 0:
41
- raise ValueError(
42
- f'Test size needs to be positive. Found: {test_size}. '
43
- 'Either specify a positive `float` (fraction) or '
44
- 'a positive `int` (size).'
45
- )
46
- if train_size is not None and train_size <= 0:
47
- raise ValueError(
48
- f'Train size needs to be None or positive. Found: {train_size}. '
49
- 'Either specify `None`, a positive `float` (fraction) or '
50
- 'a positive `int` (size).'
51
- )
52
- if validation_size is not None and validation_size <= 0:
42
+ indices = np.unique(groups)
43
+ size = len(indices)
44
+
45
+ if not train_size and not test_size:
53
46
  raise ValueError(
54
- f'Validation size needs to be None or positive. Found: {validation_size}. '
55
- 'Either specify `None`, a positive `float` (fraction) or '
56
- 'a positive `int` (size).'
47
+ f'Found both `train_size` and `test_size` to be `None`, '
48
+ f'specify at least one of them.'
57
49
  )
58
-
59
50
  if isinstance(test_size, float):
60
51
  test_size = int(size * test_size)
61
- if validation_size and isinstance(validation_size, float):
52
+ if isinstance(train_size, float):
53
+ train_size = int(size * train_size)
54
+ if isinstance(validation_size, float):
62
55
  validation_size = int(size * validation_size)
63
56
  elif not validation_size:
64
57
  validation_size = 0
65
58
 
66
- if train_size and isinstance(train_size, float):
67
- train_size = int(size * train_size)
68
- elif not train_size:
69
- train_size = 0
70
-
71
59
  if not train_size:
72
- train_size = size - test_size
73
- if not validation_size:
74
- train_size -= validation_size
75
-
60
+ train_size = (size - test_size - validation_size)
61
+ if not test_size:
62
+ test_size = (size - train_size - validation_size)
63
+
76
64
  remainder = size - (train_size + validation_size + test_size)
77
-
78
65
  if remainder < 0:
79
66
  raise ValueError(
80
- 'Sizes of data subsets add up to more than the size of the original data set: '
81
- f'{size} < ({train_size} + {validation_size} + {test_size})'
67
+ f'subset sizes added up to more than the data size.'
82
68
  )
83
- if test_size <= 0:
84
- raise ValueError(
85
- f'Test size needs to be greater than 0. Found: {test_size}.'
86
- )
87
- if train_size <= 0:
88
- raise ValueError(
89
- f'Train size needs to be greater than 0. Found: {train_size}.'
90
- )
91
-
92
69
  train_size += remainder
93
70
 
94
- if isinstance(data, pd.DataFrame):
95
- if shuffle:
96
- data = data.sample(
97
- frac=1.0, replace=False, random_state=random_state
98
- )
99
- train_data = data.iloc[:train_size]
100
- test_data = data.iloc[-test_size:]
101
- if not validation_size:
102
- return train_data, test_data
103
- validation_data = data.iloc[train_size:-test_size]
104
- return train_data, validation_data, test_data
105
-
106
- if not isinstance(data, np.ndarray):
107
- data = np.asarray(data)
108
-
109
- np.random.seed(random_state)
110
-
111
- random_indices = np.arange(size)
112
- np.random.shuffle(random_indices)
113
- data = data[random_indices]
71
+ if shuffle:
72
+ np.random.seed(random_state)
73
+ np.random.shuffle(indices)
114
74
 
115
- train_data = data[:train_size]
116
- test_data = data[-test_size:]
75
+ train_mask = np.isin(groups, indices[:train_size])
76
+ test_mask = np.isin(groups, indices[-test_size:])
117
77
  if not validation_size:
118
- return train_data, test_data
119
- validation_data = data[train_size:-test_size]
120
- return train_data, validation_data, test_data
78
+ return data[train_mask], data[test_mask]
79
+ validation_mask = np.isin(groups, indices[train_size:-test_size])
80
+ return data[train_mask], data[validation_mask], data[test_mask]
121
81
 
82
+ def cv_split(
83
+ data: pd.DataFrame | np.ndarray,
84
+ num_splits: int = 10,
85
+ groups: str | np.ndarray = None,
86
+ shuffle: bool = False,
87
+ random_state: int | None = None,
88
+ ) -> typing.Iterator[
89
+ tuple[np.ndarray | pd.DataFrame, np.ndarray | pd.DataFrame]
90
+ ]:
91
+ """Splits the dataset into cross-validation folds.
122
92
 
93
+ Args:
94
+ data:
95
+ A pd.DataFrame or np.ndarray object.
96
+ num_splits:
97
+ The number of cross-validation folds.
98
+ groups:
99
+ The groups to perform the splitting on.
100
+ shuffle:
101
+ Whether the dataset should be shuffled prior to splitting.
102
+ random_state:
103
+ The random state/seed. Only applicable if shuffling.
104
+ """
105
+ if not isinstance(data, (pd.DataFrame, np.ndarray)):
106
+ raise ValueError(f'Unsupported `data` type ({type(data)}).')
107
+
108
+ if isinstance(groups, str):
109
+ groups = data[groups].values
110
+ elif groups is None:
111
+ groups = np.arange(len(data))
112
+
113
+ indices = np.unique(groups)
114
+ size = len(indices)
123
115
 
116
+ if num_splits > size:
117
+ raise ValueError(
118
+ f'`num_splits` ({num_splits}) must not be greater than'
119
+ f'the data size or the number of groups ({size}).'
120
+ )
121
+ if shuffle:
122
+ np.random.seed(random_state)
123
+ np.random.shuffle(indices)
124
+
125
+ indices_splits = np.array_split(indices, num_splits)
126
+
127
+ for k in range(num_splits):
128
+ test_indices = indices_splits[k]
129
+ test_mask = np.isin(groups, test_indices)
130
+ train_mask = ~test_mask
131
+ yield data[train_mask], data[test_mask]
molcraft/descriptors.py CHANGED
@@ -37,19 +37,21 @@ class MolWeight(Descriptor):
37
37
 
38
38
 
39
39
  @keras.saving.register_keras_serializable(package='molcraft')
40
- class TPSA(Descriptor):
40
+ class TotalPolarSurfaceArea(Descriptor):
41
41
  def call(self, mol: chem.Mol) -> np.ndarray:
42
42
  return rdMolDescriptors.CalcTPSA(mol)
43
43
 
44
44
 
45
45
  @keras.saving.register_keras_serializable(package='molcraft')
46
- class CrippenLogP(Descriptor):
46
+ class LogP(Descriptor):
47
+ """Crippen logP."""
47
48
  def call(self, mol: chem.Mol) -> np.ndarray:
48
49
  return rdMolDescriptors.CalcCrippenDescriptors(mol)[0]
49
50
 
50
51
 
51
52
  @keras.saving.register_keras_serializable(package='molcraft')
52
- class CrippenMolarRefractivity(Descriptor):
53
+ class MolarRefractivity(Descriptor):
54
+ """Crippen molar refractivity."""
53
55
  def call(self, mol: chem.Mol) -> np.ndarray:
54
56
  return rdMolDescriptors.CalcCrippenDescriptors(mol)[1]
55
57
 
molcraft/features.py CHANGED
@@ -276,37 +276,42 @@ class IsHydrogenAcceptor(Feature):
276
276
  class IsInRing(Feature):
277
277
  def call(self, mol: chem.Mol) -> list[int, float, str]:
278
278
  return [atom.IsInRing() for atom in mol.atoms]
279
-
279
+
280
280
 
281
281
  @keras.saving.register_keras_serializable(package='molcraft')
282
- class CrippenLogPContribution(Feature):
283
- def call(self, mol: chem.Mol) -> list[int, float, str]:
284
- return chem.logp_contributions(mol)
282
+ class PartialCharge(Feature):
283
+ """Gasteiger partial charge."""
284
+ def call(self, mol: chem.Mol) -> list[int, float, str]:
285
+ return chem.partial_charges(mol)
285
286
 
286
287
 
287
288
  @keras.saving.register_keras_serializable(package='molcraft')
288
- class CrippenMolarRefractivityContribution(Feature):
289
+ class TotalPolarSurfaceAreaContribution(Feature):
290
+ """Total polar surface area (TPSA) contribution."""
289
291
  def call(self, mol: chem.Mol) -> list[int, float, str]:
290
- return chem.molar_refractivity_contribution(mol)
291
-
292
+ return chem.total_polar_surface_area_contributions(mol)
292
293
 
293
- @keras.saving.register_keras_serializable(package='molcraft')
294
- class TPSAContribution(Feature):
295
- def call(self, mol: chem.Mol) -> list[int, float, str]:
296
- return chem.tpsa_contribution(mol)
297
-
298
294
 
299
295
  @keras.saving.register_keras_serializable(package='molcraft')
300
- class LabuteASAContribution(Feature):
296
+ class AccessibleSurfaceAreaContribution(Feature):
297
+ """Labute accessible surface area (ASA) contribution."""
301
298
  def call(self, mol: chem.Mol) -> list[int, float, str]:
302
- return chem.asa_contribution(mol)
299
+ return chem.accessible_surface_area_contributions(mol)
303
300
 
304
301
 
305
302
  @keras.saving.register_keras_serializable(package='molcraft')
306
- class GasteigerCharge(Feature):
307
- def call(self, mol: chem.Mol) -> list[int, float, str]:
308
- return chem.gasteiger_charges(mol)
309
-
303
+ class LogPContribution(Feature):
304
+ """Crippen logP contribution."""
305
+ def call(self, mol: chem.Mol) -> list[int, float, str]:
306
+ return chem.logp_contributions(mol)
307
+
308
+
309
+ @keras.saving.register_keras_serializable(package='molcraft')
310
+ class MolarRefractivityContribution(Feature):
311
+ """Crippen molar refractivity contribution."""
312
+ def call(self, mol: chem.Mol) -> list[int, float, str]:
313
+ return chem.molar_refractivity_contributions(mol)
314
+
310
315
 
311
316
  @keras.saving.register_keras_serializable(package='molcraft')
312
317
  class BondType(Feature):
molcraft/featurizers.py CHANGED
@@ -192,9 +192,9 @@ class MolGraphFeaturizer(Featurizer):
192
192
  if default_molecule_features:
193
193
  molecule_features = [
194
194
  descriptors.MolWeight(),
195
- descriptors.TPSA(),
196
- descriptors.CrippenLogP(),
197
- descriptors.CrippenMolarRefractivity(),
195
+ descriptors.TotalPolarSurfaceArea(),
196
+ descriptors.LogP(),
197
+ descriptors.MolarRefractivity(),
198
198
  descriptors.NumHeavyAtoms(),
199
199
  descriptors.NumHeteroatoms(),
200
200
  descriptors.NumHydrogenDonors(),
molcraft/models.py CHANGED
@@ -250,7 +250,7 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
250
250
  val_size = int(val_split * x.num_subgraphs)
251
251
  x_val = _make_dataset(x[-val_size:], batch_size)
252
252
  x = x[:-val_size]
253
- x = _make_dataset(x, batch_size)
253
+ x = _make_dataset(x, batch_size, shuffle=kwargs.get('shuffle', True))
254
254
  return super().fit(x, validation_data=x_val, **kwargs)
255
255
 
256
256
  def evaluate(self, x: tensors.GraphTensor | tf.data.Dataset, **kwargs):
@@ -561,9 +561,8 @@ def _functional_init_arguments(args, kwargs):
561
561
  or ("inputs" in kwargs and "outputs" in kwargs)
562
562
  )
563
563
 
564
- def _make_dataset(x: tensors.GraphTensor, batch_size: int):
565
- return (
566
- tf.data.Dataset.from_tensor_slices(x)
567
- .batch(batch_size)
568
- .prefetch(-1)
569
- )
564
+ def _make_dataset(x: tensors.GraphTensor, batch_size: int, shuffle: bool = False):
565
+ ds = tf.data.Dataset.from_tensor_slices(x)
566
+ if shuffle:
567
+ ds = ds.shuffle(buffer_size=ds.cardinality())
568
+ return ds.batch(batch_size).prefetch(-1)
molcraft/ops.py CHANGED
@@ -62,6 +62,22 @@ def scatter_update(
62
62
  indices = keras.ops.expand_dims(indices, axis=-1)
63
63
  return keras.ops.scatter_update(inputs, indices, updates)
64
64
 
65
+ def scatter_add(
66
+ inputs: tf.Tensor,
67
+ indices: tf.Tensor,
68
+ updates: tf.Tensor,
69
+ ) -> tf.Tensor:
70
+ if indices.dtype == tf.bool:
71
+ indices = keras.ops.stack(keras.ops.where(indices), axis=-1)
72
+ expected_rank = len(keras.ops.shape(inputs))
73
+ current_rank = len(keras.ops.shape(indices))
74
+ for _ in range(expected_rank - current_rank):
75
+ indices = keras.ops.expand_dims(indices, axis=-1)
76
+ if backend.backend() == 'tensorflow':
77
+ return tf.tensor_scatter_nd_add(inputs, indices, updates)
78
+ updates = scatter_update(keras.ops.zeros_like(inputs), indices, updates)
79
+ return inputs + updates
80
+
65
81
  def edge_softmax(
66
82
  score: tf.Tensor,
67
83
  edge_target: tf.Tensor
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a12
3
+ Version: 0.1.0a14
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -0,0 +1,21 @@
1
+ molcraft/__init__.py,sha256=lReyUDRgBySoe9LPZzlwv1N_x9unwr6nHxIU70u3mLU,464
2
+ molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
3
+ molcraft/chem.py,sha256=--4AdZV0TCj_cf5i-TRidNJGSFyab1ksUEMjmDi7zaM,21837
4
+ molcraft/conformers.py,sha256=K6ZtiSUNDN_fwqGP9JrPcwALLFFvlMlF_XejEJH3Sr4,4205
5
+ molcraft/datasets.py,sha256=QKHi9SUBKvJvdkRFmRQNowhrnu35pQqtujuLatOK8bE,4151
6
+ molcraft/descriptors.py,sha256=jJpT0XWu3Tx_bxnwk1rENySRkaM8cMDMaDIjG8KKvtg,3097
7
+ molcraft/features.py,sha256=GwOecLCNUIuGfbIVzsAJH4LikkzWMKj5IT7zSgGTttU,13846
8
+ molcraft/featurizers.py,sha256=QiyNEFRJdMcKZM-gJGHU6Soy300RWDtLeYw0QEkFG20,27129
9
+ molcraft/layers.py,sha256=cUpo9dqqNEnc7rNf-Dze8adFhOkTV5F9IhHOKs13OUI,60134
10
+ molcraft/losses.py,sha256=qnS2yC5g-O3n_zVea9MR6TNiFraW2yqRgePOisoUP4A,1065
11
+ molcraft/models.py,sha256=h9cRAdCOU-_UAxROC9Utuz4AR4HfFE9QqJ4geLYlynE,21878
12
+ molcraft/ops.py,sha256=TaAD26V-b7eSNKFKswWt9IExSgIBOmLqwlPPcdpt8wk,5496
13
+ molcraft/records.py,sha256=MbvYkcCunbAmpy_MWXmQ9WBGi2WvwxFUlwQSPKPvSSk,5534
14
+ molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
15
+ molcraft/apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
+ molcraft/apps/peptides.py,sha256=N5wJDGDIDRbmOmxin_dTY-odLqb0avAX9FU22U6x6c0,14576
17
+ molcraft-0.1.0a14.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
18
+ molcraft-0.1.0a14.dist-info/METADATA,sha256=1Op3VxuV9hkciALrrOXx2KnGShFI5a9n_XbhT-oPpKI,3893
19
+ molcraft-0.1.0a14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
+ molcraft-0.1.0a14.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
21
+ molcraft-0.1.0a14.dist-info/RECORD,,
@@ -1,21 +0,0 @@
1
- molcraft/__init__.py,sha256=exZr4HcSy0uUnFlh9cshJrs0MBDP-pXT2MqKjq0a2BY,464
2
- molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
3
- molcraft/chem.py,sha256=JARpv4IgFBtuNia0FLW_VF_DdmaA6e-_eZgH9dFAykA,21796
4
- molcraft/conformers.py,sha256=K6ZtiSUNDN_fwqGP9JrPcwALLFFvlMlF_XejEJH3Sr4,4205
5
- molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
- molcraft/descriptors.py,sha256=W8GLuDpc38RtwmreNsPOcn-PRvMjTfVng9ksJwcrVyM,3032
7
- molcraft/features.py,sha256=FpvT_9zk9EiOhvrk6OA5eEvUAYalquF7V6IvpiEJCns,13559
8
- molcraft/featurizers.py,sha256=1xyJ2JroFBHzcheRZ8v9P3bYBIaoiY-WCBdbbqXK4co,27126
9
- molcraft/layers.py,sha256=cUpo9dqqNEnc7rNf-Dze8adFhOkTV5F9IhHOKs13OUI,60134
10
- molcraft/losses.py,sha256=qnS2yC5g-O3n_zVea9MR6TNiFraW2yqRgePOisoUP4A,1065
11
- molcraft/models.py,sha256=0x74B4WsaZgmUrHmpX9YNr9QXqd1rR3QF_ygyegHoXU,21770
12
- molcraft/ops.py,sha256=PVxKfY_XbWCyntiSnmpyeBb-coFGT_VNNP9QzmeUwC0,4870
13
- molcraft/records.py,sha256=MbvYkcCunbAmpy_MWXmQ9WBGi2WvwxFUlwQSPKPvSSk,5534
14
- molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
15
- molcraft/apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- molcraft/apps/peptides.py,sha256=N5wJDGDIDRbmOmxin_dTY-odLqb0avAX9FU22U6x6c0,14576
17
- molcraft-0.1.0a12.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
18
- molcraft-0.1.0a12.dist-info/METADATA,sha256=zMjHudRgekPvWDmQdtV2pW9tyapaYqkntWZ4k3u9X_g,3893
19
- molcraft-0.1.0a12.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
20
- molcraft-0.1.0a12.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
21
- molcraft-0.1.0a12.dist-info/RECORD,,