molcraft 0.1.0a3__py3-none-any.whl → 0.1.0a5__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. Click here for more details.

molcraft/models.py CHANGED
@@ -270,6 +270,19 @@ class GraphModel(layers.GraphLayer, keras.models.Model):
270
270
  """
271
271
  super().load_weights(filepath, *args, **kwargs)
272
272
 
273
+ def embedding(self) -> 'FunctionalGraphModel':
274
+ model = self
275
+ if not isinstance(model, FunctionalGraphModel):
276
+ raise ValueError(
277
+ 'Currently, to extract the embedding part of the model, '
278
+ 'it needs to be a `FunctionalGraphModel`. '
279
+ )
280
+ inputs = model.input
281
+ for layer in model.layers:
282
+ if isinstance(layer, layers.Readout):
283
+ outputs = layer.output
284
+ return self.__class__(inputs, outputs, name=f'{self.name}_embedding')
285
+
273
286
  def train_step(self, tensor: tensors.GraphTensor) -> dict[str, float]:
274
287
  y = tensor.context.get('label')
275
288
  sample_weight = tensor.context.get('weight')
@@ -381,7 +394,7 @@ def interpret(
381
394
 
382
395
  def saliency(
383
396
  model: GraphModel,
384
- graph_tensor: tensors.GraphTensor
397
+ graph_tensor: tensors.GraphTensor,
385
398
  ) -> tensors.GraphTensor:
386
399
  x = graph_tensor
387
400
  if tensors.is_ragged(x):
molcraft/ops.py CHANGED
@@ -19,9 +19,16 @@ def gather(
19
19
  def aggregate(
20
20
  node_feature: tf.Tensor,
21
21
  edge: tf.Tensor,
22
- num_nodes: tf.Tensor
22
+ num_nodes: tf.Tensor,
23
+ mode: str = 'sum',
23
24
  ) -> tf.Tensor:
24
- return keras.ops.segment_sum(node_feature, edge, num_nodes)
25
+ if mode == 'mean':
26
+ return segment_mean(
27
+ node_feature, edge, num_nodes, sorted=False
28
+ )
29
+ return keras.ops.segment_sum(
30
+ node_feature, edge, num_nodes, sorted=False
31
+ )
25
32
 
26
33
  def propagate(
27
34
  node_feature: tf.Tensor,
@@ -82,7 +89,11 @@ def segment_mean(
82
89
  sorted: bool = False,
83
90
  ) -> tf.Tensor:
84
91
  if num_segments is None:
85
- num_segments = keras.ops.max(segment_ids) + 1
92
+ num_segments = keras.ops.cond(
93
+ keras.ops.shape(segment_ids)[0] > 0,
94
+ lambda: keras.ops.max(segment_ids) + 1,
95
+ lambda: 0
96
+ )
86
97
  if backend.backend() == 'tensorflow':
87
98
  return tf.math.unsorted_segment_mean(
88
99
  data=data,
molcraft/tensors.py CHANGED
@@ -219,13 +219,13 @@ class GraphTensor(tf.experimental.BatchableExtensionType):
219
219
  raise ValueError
220
220
  return ops.gather(self.node[node_attr], self.edge[edge_type])
221
221
 
222
- def aggregate(self, edge_attr: str, edge_type: str = 'target') -> tf.Tensor:
222
+ def aggregate(self, edge_attr: str, edge_type: str = 'target', mode: str = 'sum') -> tf.Tensor:
223
223
  if edge_type != 'source' and edge_type != 'target':
224
- raise ValueError
224
+ raise ValueError('`edge_attr` needs to be `source` or `target`.')
225
225
  edge_attr = self.edge[edge_attr]
226
226
  if 'weight' in self.edge:
227
227
  edge_attr = edge_attr * self.edge['weight']
228
- return ops.aggregate(edge_attr, self.edge[edge_type], self.num_nodes)
228
+ return ops.aggregate(edge_attr, self.edge[edge_type], self.num_nodes, mode=mode)
229
229
 
230
230
  def propagate(self, add_edge_feature: bool = False):
231
231
  updated_feature = ops.propagate(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a3
3
+ Version: 0.1.0a5
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -51,11 +51,11 @@ Dynamic: license-file
51
51
 
52
52
  ## Highlights
53
53
  - Compatible with **Keras 3**
54
- - Simplified API
55
- - Fast featurization
56
- - Modular graph **layers**
57
- - Serializable graph **featurizers** and **models**
58
- - Flexible **GraphTensor**
54
+ - Customizable and serializable **featurizers**
55
+ - Customizable and serializable **layers** and **models**
56
+ - Customizable **GraphTensor**
57
+ - Fast and efficient featurization of molecular graphs
58
+ - Efficient and easy-to-use input pipelines using TF **records**
59
59
 
60
60
  ## Examples
61
61
 
@@ -0,0 +1,18 @@
1
+ molcraft/__init__.py,sha256=eTGjgMlXf3I8ThkUwgdiONb5Yc-5fWOFvY8U8WXOMwc,435
2
+ molcraft/callbacks.py,sha256=mkz4ALjJFPy8nHd2nCAuMbKceKnq4tIpZhUuUOvie2Y,1209
3
+ molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
4
+ molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
5
+ molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
+ molcraft/descriptors.py,sha256=gKqlJ3BqJLTeR2ft8isftSEaJDC8cv64eTq5IYhy4XM,3032
7
+ molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
8
+ molcraft/featurizers.py,sha256=kV5RN_Z2pELjDcwE65KYy_JagbDUueXoClpsIOFsI9I,27073
9
+ molcraft/layers.py,sha256=y-sBLXWttr-fkGZ-acL1srMB8QqeXnHotYK9KCcyJNU,70581
10
+ molcraft/models.py,sha256=0MN4PAlsacni7RfIcYm_imxuzBVL2K8w3MnaUM24DeI,18021
11
+ molcraft/ops.py,sha256=uSnBYQwxYJ1ATdDpr290bxiyQZkrSCVxlB7btlh_n2I,4112
12
+ molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
13
+ molcraft/tensors.py,sha256=8hwlad000wQ5pNLSdzd3rCXVbaUHBxUq2MbBx27dKzU,22391
14
+ molcraft-0.1.0a5.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
15
+ molcraft-0.1.0a5.dist-info/METADATA,sha256=mb5KnvJUzofmx-MNraJxyiBBug2QNIQTQDGyC1L3SDw,4201
16
+ molcraft-0.1.0a5.dist-info/WHEEL,sha256=ck4Vq1_RXyvS4Jt6SI0Vz6fyVs4GWg7AINwpsaGEgPE,91
17
+ molcraft-0.1.0a5.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
18
+ molcraft-0.1.0a5.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (78.1.0)
2
+ Generator: setuptools (80.0.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1 +0,0 @@
1
- from molcraft.experimental import peptides
@@ -1,264 +0,0 @@
1
- import re
2
- import keras
3
- import numpy as np
4
- import tensorflow as tf
5
- from rdkit import Chem
6
-
7
- from molcraft import ops
8
- from molcraft import chem
9
- from molcraft import features
10
- from molcraft import featurizers
11
- from molcraft import tensors
12
- from molcraft import descriptors
13
-
14
-
15
- def Graph(
16
- inputs,
17
- atom_features: list[features.Feature] | str | None = 'auto',
18
- bond_features: list[features.Feature] | str | None = 'auto',
19
- super_atom: bool = True,
20
- radius: int | float | None = None,
21
- self_loops: bool = False,
22
- include_hs: bool = False,
23
- **kwargs,
24
- ):
25
- featurizer = featurizers.MolGraphFeaturizer(
26
- atom_features=atom_features,
27
- bond_features=bond_features,
28
- molecule_features=[AminoAcidType()],
29
- super_atom=super_atom,
30
- radius=radius,
31
- self_loops=self_loops,
32
- include_hs=include_hs,
33
- **kwargs,
34
- )
35
-
36
- inputs = [
37
- residues[x] for x in ['G'] + inputs
38
- ]
39
- tensor_list = [featurizer(x) for x in inputs]
40
- return tf.stack(tensor_list, axis=0)
41
-
42
-
43
- def GraphLookup(graph: tensors.GraphTensor) -> 'GraphLookupLayer':
44
- lookup = GraphLookupLayer()
45
- lookup._build(graph)
46
- return lookup
47
-
48
-
49
- @keras.saving.register_keras_serializable(package='molcraft')
50
- class GraphLookupLayer(keras.layers.Layer):
51
-
52
- def call(self, indices: tf.Tensor) -> tensors.GraphTensor:
53
- indices = tf.sort(tf.unique(tf.reshape(indices, [-1]))[0])
54
- graph = self.graph[indices]
55
- sizes = graph.context['size']
56
- max_index = keras.ops.max(indices)
57
- sizes = tf.tensor_scatter_nd_update(
58
- tensor=tf.zeros([max_index + 1], dtype=indices.dtype),
59
- indices=indices[:, None],
60
- updates=sizes
61
- )
62
- graph = graph.update(
63
- {
64
- 'context': {
65
- 'size': sizes
66
- }
67
- },
68
- )
69
- return tensors.to_dict(graph)
70
-
71
- def _build(self, x):
72
-
73
- if isinstance(x, tensors.GraphTensor):
74
- tensor = tensors.to_dict(x)
75
- self._spec = tf.nest.map_structure(
76
- tf.type_spec_from_value, tensor
77
- )
78
- else:
79
- self._spec = x
80
-
81
- self._graph = tf.nest.map_structure(
82
- lambda s: self.add_weight(
83
- shape=s.shape,
84
- dtype=s.dtype,
85
- trainable=False,
86
- initializer='zeros'
87
- ),
88
- self._spec
89
- )
90
-
91
- if isinstance(x, tensors.GraphTensor):
92
- tf.nest.map_structure(
93
- lambda v, x: v.assign(x),
94
- self._graph, tensor
95
- )
96
-
97
- graph = tf.nest.map_structure(
98
- keras.ops.convert_to_tensor, self._graph
99
- )
100
- self._graph_tensor = tensors.from_dict(graph)
101
-
102
- def get_config(self):
103
- config = super().get_config()
104
- spec = keras.saving.serialize_keras_object(self._spec)
105
- config['spec'] = spec
106
- return config
107
-
108
- @classmethod
109
- def from_config(cls, config: dict) -> 'GraphLookupLayer':
110
- spec = config.pop('spec')
111
- spec = keras.saving.deserialize_keras_object(spec)
112
- layer = cls(**config)
113
- layer._build(spec)
114
- return layer
115
-
116
- @property
117
- def graph(self) -> tensors.GraphTensor:
118
- return self._graph_tensor
119
-
120
-
121
- @keras.saving.register_keras_serializable(package='molcraft')
122
- class Gather(keras.layers.Layer):
123
-
124
- def __init__(
125
- self,
126
- padding: list[tuple[int]] | tuple[int] | int = 1,
127
- mask_value: int = 0,
128
- **kwargs
129
- ) -> None:
130
- super().__init__(**kwargs)
131
- self.padding = padding
132
- self.mask_value = mask_value
133
- self.supports_masking = True
134
-
135
- def get_config(self):
136
- config = super().get_config()
137
- config['mask_value'] = self.mask_value
138
- config['padding'] = self.padding
139
- return config
140
-
141
- def call(self, inputs: tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
142
- data, indices = inputs
143
- # if self.padding:
144
- # padding = self.padding
145
- # if isinstance(self.padding, int):
146
- # padding = [(self.padding, 0)]
147
- # if isinstance(self.padding, tuple):
148
- # padding = [self.padding]
149
- # data_rank = len(keras.ops.shape(data))
150
- # for _ in range(data_rank - len(padding)):
151
- # padding.append((0, 0))
152
- # data = keras.ops.pad(data, padding)
153
- return ops.gather(data, indices)
154
-
155
- def compute_mask(
156
- self,
157
- inputs: tuple[tf.Tensor, tf.Tensor],
158
- mask: bool | None = None
159
- ) -> tf.Tensor | None:
160
- # if self.mask_value is None:
161
- # return None
162
- _, indices = inputs
163
- return keras.ops.not_equal(indices, self.mask_value)
164
-
165
-
166
- @keras.saving.register_keras_serializable(package='molcraft')
167
- class AminoAcidType(descriptors.Descriptor):
168
-
169
- def __init__(self, vocab=None, **kwargs):
170
- vocab = [
171
- "A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
172
- "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y",
173
- ]
174
- super().__init__(vocab=vocab, **kwargs)
175
-
176
- def call(self, mol: chem.Mol) -> list[str]:
177
- residue = residues_reverse.get(mol.canonical_smiles)
178
- if not residue:
179
- raise KeyError(f'Could not find {mol.canonical_smiles} in `residues_reverse`.')
180
- mol = chem.remove_hs(mol)
181
- return _extract_residue_type(residues_reverse[mol.canonical_smiles])
182
-
183
- def sequence_split(sequence: str):
184
- patterns = [
185
- r'(\[[A-Za-z0-9]+\]-[A-Z]\[[A-Za-z0-9]+\])', # N-term mod + mod
186
- r'([A-Z]\[[A-Za-z0-9]+\]-\[[A-Za-z0-9]+\])', # C-term mod + mod
187
- r'([A-Z]-\[[A-Za-z0-9]+\])', # C-term mod
188
- r'(\[[A-Za-z0-9]+\]-[A-Z])', # N-term mod
189
- r'([A-Z]\[[A-Za-z0-9]+\])', # Mod
190
- r'([A-Z])', # No mod
191
- ]
192
- return [match.group(0) for match in re.finditer("|".join(patterns), sequence)]
193
-
194
- residues = {
195
- "A": "N[C@@H](C)C(=O)O",
196
- "C": "N[C@@H](CS)C(=O)O",
197
- "C[Carbamidomethyl]": "N[C@@H](CSCC(=O)N)C(=O)O",
198
- "D": "N[C@@H](CC(=O)O)C(=O)O",
199
- "E": "N[C@@H](CCC(=O)O)C(=O)O",
200
- "F": "N[C@@H](Cc1ccccc1)C(=O)O",
201
- "G": "NCC(=O)O",
202
- "H": "N[C@@H](CC1=CN=C-N1)C(=O)O",
203
- "I": "N[C@@H](C(CC)C)C(=O)O",
204
- "K": "N[C@@H](CCCCN)C(=O)O",
205
- "K[Acetyl]": "N[C@@H](CCCCNC(=O)C)C(=O)O",
206
- "K[Crotonyl]": "N[C@@H](CCCCNC(C=CC)=O)C(=O)O",
207
- "K[Dimethyl]": "N[C@@H](CCCCN(C)C)C(=O)O",
208
- "K[Formyl]": "N[C@@H](CCCCNC=O)C(=O)O",
209
- "K[Malonyl]": "N[C@@H](CCCCNC(=O)CC(O)=O)C(=O)O",
210
- "K[Methyl]": "N[C@@H](CCCCNC)C(=O)O",
211
- "K[Propionyl]": "N[C@@H](CCCCNC(=O)CC)C(=O)O",
212
- "K[Succinyl]": "N[C@@H](CCCCNC(CCC(O)=O)=O)C(=O)O",
213
- "K[Trimethyl]": "N[C@@H](CCCC[N+](C)(C)C)C(=O)O",
214
- "L": "N[C@@H](CC(C)C)C(=O)O",
215
- "M": "N[C@@H](CCSC)C(=O)O",
216
- "M[Oxidation]": "N[C@@H](CCS(=O)C)C(=O)O",
217
- "N": "N[C@@H](CC(=O)N)C(=O)O",
218
- "P": "N1[C@@H](CCC1)C(=O)O",
219
- "P[Oxidation]": "N1CC(O)C[C@H]1C(=O)O",
220
- "Q": "N[C@@H](CCC(=O)N)C(=O)O",
221
- "R": "N[C@@H](CCCNC(=N)N)C(=O)O",
222
- "R[Deamidated]": "N[C@@H](CCCNC(N)=O)C(=O)O",
223
- "R[Dimethyl]": "N[C@@H](CCCNC(N(C)C)=N)C(=O)O",
224
- "R[Methyl]": "N[C@@H](CCCNC(=N)NC)C(=O)O",
225
- "S": "N[C@@H](CO)C(=O)O",
226
- "T": "N[C@@H](C(O)C)C(=O)O",
227
- "V": "N[C@@H](C(C)C)C(=O)O",
228
- "W": "N[C@@H](CC(=CN2)C1=C2C=CC=C1)C(=O)O",
229
- "Y": "N[C@@H](Cc1ccc(O)cc1)C(=O)O",
230
- "Y[Nitro]": "N[C@@H](Cc1ccc(O)c(N(=O)=O)c1)C(=O)O",
231
- "Y[Phospho]": "N[C@@H](Cc1ccc(OP(O)(=O)O)cc1)C(=O)O",
232
- "[Acetyl]-A": "N(C(C)=O)[C@@H](C)C(=O)O",
233
- "[Acetyl]-C": "N(C(C)=O)[C@@H](CS)C(=O)O",
234
- "[Acetyl]-D": "N(C(=O)C)[C@H](C(=O)O)CC(=O)O",
235
- "[Acetyl]-E": "N(C(=O)C)[C@@H](CCC(O)=O)C(=O)O",
236
- "[Acetyl]-F": "N(C(C)=O)[C@@H](Cc1ccccc1)C(=O)O",
237
- "[Acetyl]-G": "N(C(=O)C)CC(=O)O",
238
- "[Acetyl]-H": "N(C(=O)C)[C@@H](Cc1[nH]cnc1)C(=O)O",
239
- "[Acetyl]-I": "N(C(=O)C)[C@@H]([C@H](CC)C)C(=O)O",
240
- "[Acetyl]-K": "N(C(C)=O)[C@@H](CCCCN)C(=O)O",
241
- "[Acetyl]-L": "N(C(=O)C)[C@@H](CC(C)C)C(=O)O",
242
- "[Acetyl]-M": "N(C(=O)C)[C@@H](CCSC)C(=O)O",
243
- "[Acetyl]-N": "N(C(C)=O)[C@@H](CC(=O)N)C(=O)O",
244
- "[Acetyl]-P": "N1(C(=O)C)CCC[C@H]1C(=O)O",
245
- "[Acetyl]-Q": "N(C(=O)C)[C@@H](CCC(=O)N)C(=O)O",
246
- "[Acetyl]-R": "N(C(C)=O)[C@@H](CCCN=C(N)N)C(=O)O",
247
- "[Acetyl]-S": "N(C(C)=O)[C@@H](CO)C(=O)O",
248
- "[Acetyl]-T": "N(C(=O)C)[C@@H]([C@H](O)C)C(=O)O",
249
- "[Acetyl]-V": "N(C(=O)C)[C@@H](C(C)C)C(=O)O",
250
- "[Acetyl]-W": "N(C(C)=O)[C@@H](Cc1c2ccccc2[nH]c1)C(=O)O",
251
- "[Acetyl]-Y": "N(C(C)=O)[C@@H](Cc1ccc(O)cc1)C(=O)O"
252
- }
253
-
254
- residues_reverse = {}
255
- def register_peptide_residues(residues: dict[str, str]):
256
- for residue, smiles in residues.items():
257
- residues[residue] = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
258
- residues_reverse[residues[residue]] = residue
259
-
260
- register_peptide_residues(residues)
261
-
262
- def _extract_residue_type(residue_tag: str) -> str:
263
- pattern = r"(?<!\[)[A-Z](?![\w-])"
264
- return [match.group(0) for match in re.finditer(pattern, residue_tag)][0]
@@ -1,20 +0,0 @@
1
- molcraft/__init__.py,sha256=2ZNfWBjGl8DscOwjdDiRkgIsuPnKit29Q3MhZyP336Q,435
2
- molcraft/callbacks.py,sha256=6gwCwdsHGb-fVB4m1QGmtBwQwZ9mFq9QUkmPKSMn05U,849
3
- molcraft/chem.py,sha256=_UO5O-I7KUtGf3vRrFEYoAUGlW5xi2x8ylu5f-Ybumo,18696
4
- molcraft/conformers.py,sha256=p09gOQOdxLSj3yohZOMkxxLriHsZ1ZqOoiWLi73OpIg,4325
5
- molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
- molcraft/descriptors.py,sha256=x6RfZ-gK7D_WSvmK6sh6yHyEjQqovPnRU0xwC3dAKfg,2880
7
- molcraft/features.py,sha256=69oV_GHNdBKPA4sp6Tpo6brvNmaauk_IVIzNjX7VDmg,13648
8
- molcraft/featurizers.py,sha256=Yu8I6I_zkzB__WYSiqz-FDGjvKFOmyWFxojRBr39Aw8,26236
9
- molcraft/layers.py,sha256=HjnAtqhuP0uZ5yP4L33k3xT4IUdLavWBrjd3wO9_Rmw,64915
10
- molcraft/models.py,sha256=DXqWR_XnMVXQseVR91XnDLXvmHa1hv-6_Y_wvpQZBFI,17476
11
- molcraft/ops.py,sha256=iiE6zgA2P7cmjKO1RHmL9GE_Tv7Tyuo_xDoxB_ELZQM,3824
12
- molcraft/records.py,sha256=w4-bcWZEC0oVInrE1e0kQBroIaSCA0PN1JBPOtO6VUY,5251
13
- molcraft/tensors.py,sha256=b7PO-YOvV72s9g057ILJACKS2n2fn10VkO35gHXpssI,22312
14
- molcraft/experimental/__init__.py,sha256=x5h6LOO8bo3NPjkKKM9M1H-Kz6R3yxYhRSePoxHCdRE,42
15
- molcraft/experimental/peptides.py,sha256=RCuOTSwoYHGSdeYi6TWHdPIv2WC3avCZjKLdhEZQeXw,8997
16
- molcraft-0.1.0a3.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
17
- molcraft-0.1.0a3.dist-info/METADATA,sha256=f_5sBinpFcGSqKLaSqGkZJ83gGQtZw1Pb3fkgq9aCBM,4088
18
- molcraft-0.1.0a3.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
19
- molcraft-0.1.0a3.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
20
- molcraft-0.1.0a3.dist-info/RECORD,,