molcraft 0.1.0a9__py3-none-any.whl → 0.1.0a10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of molcraft might be problematic. See the package's registry page for more details.

molcraft/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = '0.1.0a9'
1
+ __version__ = '0.1.0a10'
2
2
 
3
3
  import os
4
4
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
molcraft/layers.py CHANGED
@@ -350,7 +350,7 @@ class GraphConv(GraphLayer):
350
350
  )
351
351
  if self._project_residual:
352
352
  warnings.warn(
353
- '`skip_connect` is set to `True`, but found incompatible dim '
353
+ '`skip_connect` is set to `True`, but found incompatible dim '
354
354
  'between input (node feature dim) and output (`self.units`). '
355
355
  'Automatically applying a projection layer to residual to '
356
356
  'match input and output. ',
@@ -369,7 +369,7 @@ class GraphConv(GraphLayer):
369
369
  self._message_intermediate_activation = self.activation
370
370
  self._message_final_dense = self.get_dense(self.units)
371
371
 
372
- has_overridden_aggregate = self.__class__.message != GraphConv.aggregate
372
+ has_overridden_aggregate = self.__class__.message != GraphConv.aggregate
373
373
  if not has_overridden_aggregate:
374
374
  pass
375
375
 
@@ -401,13 +401,15 @@ class GraphConv(GraphLayer):
401
401
  residual = self._residual_dense(residual)
402
402
 
403
403
  message = self.message(tensor)
404
- if not isinstance(message, tensors.GraphTensor):
404
+ add_message = not isinstance(message, tensors.GraphTensor)
405
+ if add_message:
405
406
  message = tensor.update({'edge': {'message': message}})
406
407
  elif not 'message' in message.edge:
407
408
  raise ValueError('Could not find `message` in `edge` output.')
408
-
409
+
409
410
  aggregate = self.aggregate(message)
410
- if not isinstance(aggregate, tensors.GraphTensor):
411
+ add_aggregate = not isinstance(aggregate, tensors.GraphTensor)
412
+ if add_aggregate:
411
413
  aggregate = tensor.update({'node': {'aggregate': aggregate}})
412
414
  elif not 'aggregate' in aggregate.node:
413
415
  raise ValueError('Could not find `aggregate` in `node` output.')
@@ -421,6 +423,16 @@ class GraphConv(GraphLayer):
421
423
  if update.node['feature'].shape[-1] != self.units:
422
424
  raise ValueError('Updated node `feature` is not equal to `self.units`.')
423
425
 
426
+ if add_message and add_aggregate:
427
+ update = update.update({'node': {'aggregate': None}, 'edge': {'message': None}})
428
+ elif add_message:
429
+ update = update.update({'edge': {'message': None}})
430
+ elif add_aggregate:
431
+ update = update.update({'node': {'aggregate': None}})
432
+
433
+ if not self._skip_connect and not self._normalize:
434
+ return update
435
+
424
436
  feature = update.node['feature']
425
437
 
426
438
  if self._skip_connect:
molcraft/ops.py CHANGED
@@ -105,7 +105,11 @@ def segment_mean(
105
105
  lambda: 0
106
106
  )
107
107
  if backend.backend() == 'tensorflow':
108
- return tf.math.unsorted_segment_mean(
108
+ segment_mean_fn = (
109
+ tf.math.unsorted_segment_mean if not sorted else
110
+ tf.math.segment_mean
111
+ )
112
+ return segment_mean_fn(
109
113
  data=data,
110
114
  segment_ids=segment_ids,
111
115
  num_segments=num_segments
molcraft/records.py CHANGED
@@ -51,19 +51,24 @@ def write(
51
51
  if num_files is None:
52
52
  num_files = min(len(inputs), max(1, math.ceil(len(inputs) / 1_000)))
53
53
 
54
- chunk_size = math.ceil(len(inputs) / num_files)
55
- num_files = math.ceil(len(inputs) / chunk_size)
54
+ num_examples = len(inputs)
55
+ chunk_sizes = [0] * num_files
56
+ for i in range(num_examples):
57
+ chunk_sizes[i % num_files] += 1
58
+
59
+ input_chunks = []
60
+ current_index = 0
61
+ for size in chunk_sizes:
62
+ input_chunks.append(inputs[current_index: current_index + size])
63
+ current_index += size
64
+
65
+ assert current_index == num_examples
56
66
 
57
67
  paths = [
58
68
  os.path.join(path, f'tfrecord-{i:04d}.tfrecord')
59
69
  for i in range(num_files)
60
70
  ]
61
71
 
62
- input_chunks = [
63
- inputs[i * chunk_size: (i + 1) * chunk_size]
64
- for i in range(num_files)
65
- ]
66
-
67
72
  if not multiprocessing:
68
73
  for path, input_chunk in zip(paths, input_chunks):
69
74
  _write_tfrecord(input_chunk, path, featurizer)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: molcraft
3
- Version: 0.1.0a9
3
+ Version: 0.1.0a10
4
4
  Summary: Graph Neural Networks for Molecular Machine Learning
5
5
  Author-email: Alexander Kensert <alexander.kensert@gmail.com>
6
6
  License: MIT License
@@ -25,7 +25,7 @@ License: MIT License
25
25
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
26
  SOFTWARE.
27
27
 
28
- Project-URL: Homepage, https://github.com/akensert/molcraft
28
+ Project-URL: Homepage, https://github.com/compomics/molcraft
29
29
  Keywords: python,machine-learning,deep-learning,graph-neural-networks,molecular-machine-learning,molecular-graphs,computational-chemistry,computational-biology
30
30
  Classifier: Programming Language :: Python :: 3
31
31
  Classifier: Intended Audience :: Science/Research
@@ -1,4 +1,4 @@
1
- molcraft/__init__.py,sha256=8f1z8Lhuhh8TxB-QGHI5w4a3M_ZZNH8EWGD4Y6pB578,463
1
+ molcraft/__init__.py,sha256=1te1sOK-k4LT9l-mlxRmOhY3_Za-7jPezx_B3gdToiQ,464
2
2
  molcraft/callbacks.py,sha256=x5HnkZhqcFRrW6xdApt_jZ4X08A-0fxcnFKfdmRKa0c,3571
3
3
  molcraft/chem.py,sha256=zHH7iX0ZJ7QmP-YqR_IXCpylTwCXHXptWf1DsblnZR4,21496
4
4
  molcraft/conformers.py,sha256=K6ZtiSUNDN_fwqGP9JrPcwALLFFvlMlF_XejEJH3Sr4,4205
@@ -6,14 +6,14 @@ molcraft/datasets.py,sha256=rFgXTC1ZheLhfgQgcCspP_wEE54a33PIneH7OplbS-8,4047
6
6
  molcraft/descriptors.py,sha256=gKqlJ3BqJLTeR2ft8isftSEaJDC8cv64eTq5IYhy4XM,3032
7
7
  molcraft/features.py,sha256=aBYxDfQqQsVuyjKaPUlwEgvCjbNZ-FJhuKo2Cg5ajrA,13554
8
8
  molcraft/featurizers.py,sha256=ybJ1djH747cgsftztWHxAX2iTq6k03MYr17btQ2Gtcs,27063
9
- molcraft/layers.py,sha256=r6hEAyJxO_Yrw5hD1r2v8yb_UxLRK9S4FMjDCUQedH8,59655
9
+ molcraft/layers.py,sha256=G-ZFhnyiSny0YHGXg5tBYrvmhZsurBEJj_0mHD1zmlw,60135
10
10
  molcraft/losses.py,sha256=JEKZEX2f8vDgky_fUocsF8vZjy9VMzRjZUBa20Uf9Qw,1065
11
11
  molcraft/models.py,sha256=FLXpO3OUmRxLmxG3MjBK4ZwcVFlea1gqEgs1ibKly2w,23263
12
- molcraft/ops.py,sha256=dLIUq-KG8nOzEcphJqNbF_f82VZRDNrB1UKrcPt5JNM,4752
13
- molcraft/records.py,sha256=0sjOdcr266ZER4F-aTBQ3AVPNAwflKWNiNJVsSc1-PQ,5370
12
+ molcraft/ops.py,sha256=PVxKfY_XbWCyntiSnmpyeBb-coFGT_VNNP9QzmeUwC0,4870
13
+ molcraft/records.py,sha256=MbvYkcCunbAmpy_MWXmQ9WBGi2WvwxFUlwQSPKPvSSk,5534
14
14
  molcraft/tensors.py,sha256=EOUKx496KUZsjA1zA2ABc7tU_TW3Jv7AXDsug_QsLbA,22407
15
- molcraft-0.1.0a9.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
16
- molcraft-0.1.0a9.dist-info/METADATA,sha256=HiwS2wmntCA7m_YpgSWKiJTP0BFpl4GWWz4a77w1XBw,4062
17
- molcraft-0.1.0a9.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
18
- molcraft-0.1.0a9.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
19
- molcraft-0.1.0a9.dist-info/RECORD,,
15
+ molcraft-0.1.0a10.dist-info/licenses/LICENSE,sha256=sbVeqlrtZ0V63uYhZGL5dCxUm8rBAOqe2avyA1zIQNk,1074
16
+ molcraft-0.1.0a10.dist-info/METADATA,sha256=Tmh4KckmdKr20q8RVPOKlogt343qTdOMzci6zgT6CfQ,4064
17
+ molcraft-0.1.0a10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
18
+ molcraft-0.1.0a10.dist-info/top_level.txt,sha256=dENV6MfOceshM6MQCgJlcN1ojZkiCL9B4F7XyUge3QM,9
19
+ molcraft-0.1.0a10.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.7.1)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5