nerdd-module 0.3.37__tar.gz → 0.3.39__tar.gz

This diff compares the contents of two publicly available package versions, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these package versions as they appear in their respective public registries.
Files changed (99)
  1. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/PKG-INFO +1 -1
  2. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/cli.py +2 -2
  3. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/config/models.py +2 -2
  4. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/model/__init__.py +1 -1
  5. nerdd_module-0.3.37/nerdd_module/model/simple_model.py → nerdd_module-0.3.39/nerdd_module/model/model.py +57 -28
  6. nerdd_module-0.3.37/nerdd_module/model/model.py → nerdd_module-0.3.39/nerdd_module/model/prediction_step.py +86 -117
  7. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/problem.py +14 -0
  8. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/tests/models/AtomicMassModel.py +2 -2
  9. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/tests/models/MolWeightModel.py +2 -2
  10. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module.egg-info/PKG-INFO +1 -1
  11. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module.egg-info/SOURCES.txt +1 -1
  12. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/pyproject.toml +1 -1
  13. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/LICENSE +0 -0
  14. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/README.md +0 -0
  15. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/__init__.py +0 -0
  16. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/config/__init__.py +0 -0
  17. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/config/configuration.py +0 -0
  18. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/config/default_configuration.py +0 -0
  19. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/config/dict_configuration.py +0 -0
  20. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/config/merged_configuration.py +0 -0
  21. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/config/package_configuration.py +0 -0
  22. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/config/search_yaml_configuration.py +0 -0
  23. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/config/yaml_configuration.py +0 -0
  24. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/converters/__init__.py +0 -0
  25. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/converters/basic_type_converter.py +0 -0
  26. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/converters/converter.py +0 -0
  27. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/converters/converter_config.py +0 -0
  28. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/converters/mol_converter.py +0 -0
  29. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/converters/problem_list_converter.py +0 -0
  30. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/converters/representation_converter.py +0 -0
  31. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/converters/source_list_converter.py +0 -0
  32. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/converters/void_converter.py +0 -0
  33. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/__init__.py +0 -0
  34. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/depth_first_explorer.py +0 -0
  35. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/explorer.py +0 -0
  36. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/file_reader.py +0 -0
  37. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/gzip_reader.py +0 -0
  38. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/inchi_reader.py +0 -0
  39. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/list_reader.py +0 -0
  40. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/mol_reader.py +0 -0
  41. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/reader.py +0 -0
  42. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/reader_config.py +0 -0
  43. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/sdf_reader.py +0 -0
  44. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/smiles_reader.py +0 -0
  45. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/string_reader.py +0 -0
  46. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/tar_reader.py +0 -0
  47. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/input/zip_reader.py +0 -0
  48. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/model/assign_name_step.py +0 -0
  49. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/model/convert_representations_step.py +0 -0
  50. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/model/enforce_schema_step.py +0 -0
  51. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/model/read_input_step.py +0 -0
  52. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/model/write_output_step.py +0 -0
  53. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/output/__init__.py +0 -0
  54. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/output/csv_writer.py +0 -0
  55. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/output/file_writer.py +0 -0
  56. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/output/iterator_writer.py +0 -0
  57. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/output/pandas_writer.py +0 -0
  58. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/output/record_list_writer.py +0 -0
  59. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/output/sdf_writer.py +0 -0
  60. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/output/writer.py +0 -0
  61. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/output/writer_config.py +0 -0
  62. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/polyfills/__init__.py +0 -0
  63. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/polyfills/block_logs.py +0 -0
  64. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/polyfills/files.py +0 -0
  65. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/polyfills/get_entry_points.py +0 -0
  66. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/polyfills/literal.py +0 -0
  67. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/polyfills/typed_dict.py +0 -0
  68. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/polyfills/types.py +0 -0
  69. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/polyfills/version.py +0 -0
  70. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/preprocessing/__init__.py +0 -0
  71. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/preprocessing/check_valid_smiles.py +0 -0
  72. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/preprocessing/chembl_structure_pipeline.py +0 -0
  73. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/preprocessing/filter_by_element.py +0 -0
  74. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/preprocessing/filter_by_weight.py +0 -0
  75. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/preprocessing/preprocessing_step.py +0 -0
  76. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/preprocessing/remove_stereochemistry.py +0 -0
  77. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/preprocessing/sanitize.py +0 -0
  78. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/py.typed +0 -0
  79. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/steps/__init__.py +0 -0
  80. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/steps/map_step.py +0 -0
  81. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/steps/output_step.py +0 -0
  82. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/steps/step.py +0 -0
  83. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/tests/__init__.py +0 -0
  84. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/tests/checks.py +0 -0
  85. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/tests/files.py +0 -0
  86. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/tests/models/__init__.py +0 -0
  87. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/tests/predictions.py +0 -0
  88. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/tests/preprocessing/DummyPreprocessingStep.py +0 -0
  89. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/tests/preprocessing/__init__.py +0 -0
  90. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/tests/representations.py +0 -0
  91. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/util/__init__.py +0 -0
  92. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/util/call_with_mappings.py +0 -0
  93. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/util/package.py +0 -0
  94. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module/version.py +0 -0
  95. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module.egg-info/dependency_links.txt +0 -0
  96. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module.egg-info/requires.txt +0 -0
  97. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/nerdd_module.egg-info/top_level.txt +0 -0
  98. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/setup.cfg +0 -0
  99. {nerdd_module-0.3.37 → nerdd_module-0.3.39}/tests/test_features.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nerdd-module
3
- Version: 0.3.37
3
+ Version: 0.3.39
4
4
  Summary: Base package to create NERDD modules
5
5
  Author-email: Steffen Hirte <steffen.hirte@univie.ac.at>
6
6
  Maintainer-email: Steffen Hirte <steffen.hirte@univie.ac.at>
@@ -56,7 +56,7 @@ def auto_cli(f: Callable[..., Model], *args: Any, **kwargs: Any) -> None:
56
56
  input_format_list = "\n".join([f"* {fmt}" for fmt in ["smiles", "sdf", "inchi"]])
57
57
 
58
58
  help_text = input_description.format(
59
- description=model.description, input_format_list=input_format_list
59
+ description=model.config.description, input_format_list=input_format_list
60
60
  )
61
61
 
62
62
  output_format_list = [
@@ -117,7 +117,7 @@ def auto_cli(f: Callable[..., Model], *args: Any, **kwargs: Any) -> None:
117
117
  #
118
118
  # Add job parameters
119
119
  #
120
- for param in model.job_parameters:
120
+ for param in model.config.job_parameters:
121
121
  # convert parameter name to spinal case (e.g. "max_confs" -> "max-confs")
122
122
  param_name = spinalcase(param.name)
123
123
  main = click.option(
@@ -123,7 +123,7 @@ class Module(BaseModel):
123
123
  return spinalcase(self.name)
124
124
 
125
125
  task: Optional[Task] = None
126
- rank: Optional[int] = None
126
+ rank: Optional[float] = None
127
127
  name: str
128
128
  batch_size: int = 100
129
129
  version: Optional[str] = None
@@ -203,7 +203,7 @@ class Module(BaseModel):
203
203
  for i, j in zip(indices[:-1], indices[1:]):
204
204
  assert i + 1 == j, (
205
205
  f"Properties with the same group should appear next to each other, "
206
- f"but group {group} appears at incides {i} and {j}."
206
+ f"but group {group} appears at indices {i} and {j}."
207
207
  )
208
208
 
209
209
  return values
@@ -1,6 +1,6 @@
1
1
  from .assign_name_step import *
2
2
  from .convert_representations_step import *
3
3
  from .model import *
4
+ from .prediction_step import *
4
5
  from .read_input_step import *
5
- from .simple_model import *
6
6
  from .write_output_step import *
@@ -1,4 +1,5 @@
1
- from abc import abstractmethod
1
+ import logging
2
+ from abc import ABC, abstractmethod
2
3
  from functools import cached_property
3
4
  from typing import Any, Iterable, List, Optional, Tuple, Union
4
5
 
@@ -8,7 +9,6 @@ from ..config import (
8
9
  Configuration,
9
10
  DefaultConfiguration,
10
11
  DictConfiguration,
11
- JobParameter,
12
12
  MergedConfiguration,
13
13
  Module,
14
14
  PackageConfiguration,
@@ -17,21 +17,22 @@ from ..config import (
17
17
  from ..input import DepthFirstExplorer
18
18
  from ..preprocessing import PreprocessingStep
19
19
  from ..problem import Problem
20
- from ..steps import Step
20
+ from ..steps import OutputStep, Step
21
21
  from ..util import get_file_path_to_instance
22
22
  from .assign_name_step import AssignNameStep
23
23
  from .convert_representations_step import ConvertRepresentationsStep
24
24
  from .enforce_schema_step import EnforceSchemaStep
25
- from .model import Model
25
+ from .prediction_step import PredictionStep
26
26
  from .read_input_step import ReadInputStep
27
27
  from .write_output_step import WriteOutputStep
28
28
 
29
- __all__ = ["SimpleModel"]
29
+ logger = logging.getLogger(__name__)
30
30
 
31
31
 
32
- class SimpleModel(Model):
32
+ class Model(ABC):
33
33
  def __init__(self, preprocessing_steps: Iterable[Step] = []) -> None:
34
34
  super().__init__()
35
+
35
36
  assert isinstance(
36
37
  preprocessing_steps, Iterable
37
38
  ), f"Expected Iterable for argument preprocessing_steps, got {type(preprocessing_steps)}"
@@ -39,8 +40,12 @@ class SimpleModel(Model):
39
40
  f"Expected all elements of preprocessing_steps to be of type Step, "
40
41
  f"got {[type(step) for step in preprocessing_steps if not isinstance(step, Step)]}"
41
42
  )
43
+
42
44
  self._preprocessing_steps = preprocessing_steps
43
45
 
46
+ def _preprocess(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]:
47
+ return mol, []
48
+
44
49
  def _get_input_steps(
45
50
  self, input: Any, input_format: Optional[str], **kwargs: Any
46
51
  ) -> List[Step]:
@@ -59,6 +64,10 @@ class SimpleModel(Model):
59
64
  CustomPreprocessingStep(self),
60
65
  ]
61
66
 
67
+ @abstractmethod
68
+ def _predict_mols(self, mols: List[Mol], **kwargs: Any) -> Iterable[dict]:
69
+ pass
70
+
62
71
  def _get_postprocessing_steps(self, output_format: Optional[str], **kwargs: Any) -> List[Step]:
63
72
  output_format = output_format or "pandas"
64
73
  return [
@@ -67,13 +76,46 @@ class SimpleModel(Model):
67
76
  WriteOutputStep(output_format, config=self.config, **kwargs),
68
77
  ]
69
78
 
70
- def _preprocess(self, mol: Mol) -> Tuple[Optional[Mol], List[Problem]]:
71
- return mol, []
79
+ def predict(
80
+ self,
81
+ input: Any,
82
+ input_format: Optional[str] = None,
83
+ output_format: Optional[str] = None,
84
+ **kwargs: Any,
85
+ ) -> Any:
86
+ input_steps = self._get_input_steps(input, input_format, **kwargs)
87
+ preprocessing_steps = self._get_preprocessing_steps(input, input_format, **kwargs)
88
+ postprocessing_steps = self._get_postprocessing_steps(output_format, **kwargs)
89
+ output_step = postprocessing_steps[-1]
90
+
91
+ assert isinstance(output_step, OutputStep), "The last step must be an OutputStep."
92
+
93
+ # make mypy happy by restricting the type of self.config.task
94
+ assert self.config.task is not None
95
+
96
+ steps = [
97
+ *input_steps,
98
+ *preprocessing_steps,
99
+ PredictionStep(
100
+ self._predict_mols,
101
+ task=self.config.task,
102
+ batch_size=self.config.batch_size,
103
+ **kwargs,
104
+ ),
105
+ *postprocessing_steps,
106
+ ]
72
107
 
73
- @abstractmethod
74
- def _predict_mols(self, mols: List[Mol], **kwargs: Any) -> List[dict]:
75
- pass
108
+ # build the pipeline from the list of steps
109
+ pipeline = None
110
+ for t in steps:
111
+ pipeline = t(pipeline)
76
112
 
113
+ # the last pipeline step holds the result
114
+ return output_step.get_result()
115
+
116
+ #
117
+ # Configuration
118
+ #
77
119
  def _get_base_config(self) -> Union[Configuration, dict]:
78
120
  # get the class of the nerdd module, e.g. <CypstrateModel>
79
121
  nerdd_module_class = self.__class__
@@ -107,6 +149,9 @@ class SimpleModel(Model):
107
149
  if isinstance(base_config, dict):
108
150
  base_config = DictConfiguration(base_config)
109
151
 
152
+ # ensure that mandatory properties are present
153
+ base_config = MergedConfiguration(DefaultConfiguration(self), base_config)
154
+
110
155
  # add default properties mol_id, raw_input, etc.
111
156
  task = base_config.get_dict().task
112
157
 
@@ -180,7 +225,6 @@ class SimpleModel(Model):
180
225
  ]
181
226
 
182
227
  configs = [
183
- DefaultConfiguration(self),
184
228
  DictConfiguration({"result_properties": default_properties_start}),
185
229
  base_config,
186
230
  DictConfiguration({"result_properties": default_properties_end}),
@@ -192,24 +236,9 @@ class SimpleModel(Model):
192
236
  def config(self) -> Module:
193
237
  return self._get_config().get_dict()
194
238
 
195
- def _get_batch_size(self) -> int:
196
- default = super()._get_batch_size()
197
- return self.config.batch_size or default
198
-
199
- def _get_name(self) -> str:
200
- default = super()._get_name()
201
- return self.config.name or default
202
-
203
- def _get_description(self) -> str:
204
- default = super()._get_description()
205
- return self.config.description or default
206
-
207
- def _get_job_parameters(self) -> List[JobParameter]:
208
- return super()._get_job_parameters() + self.config.job_parameters
209
-
210
239
 
211
240
  class CustomPreprocessingStep(PreprocessingStep):
212
- def __init__(self, model: SimpleModel):
241
+ def __init__(self, model: Model):
213
242
  super().__init__()
214
243
  self.model = model
215
244
 
@@ -1,114 +1,24 @@
1
1
  import logging
2
- from abc import ABC, abstractmethod
3
2
  from collections import defaultdict
4
- from typing import Any, Iterable, Iterator, List, Optional, Tuple
3
+ from typing import Any, Callable, DefaultDict, Iterator, List, Set, Tuple
5
4
 
6
- from rdkit.Chem import Mol
7
- from stringcase import snakecase # type: ignore
8
-
9
- from ..config import JobParameter
10
- from ..problem import Problem
11
- from ..steps import OutputStep, Step
5
+ from ..config import Task
6
+ from ..problem import IncompletePredictionProblem, UnknownPredictionProblem
7
+ from ..steps import Step
12
8
  from ..util import call_with_mappings
13
9
 
14
10
  logger = logging.getLogger(__name__)
15
11
 
16
-
17
- # an unknown prediction problem indicates that the model raised an exception during
18
- # prediction
19
- def UnknownPredictionProblem() -> Problem:
20
- return Problem("unknown_prediction_error", "An unknown error occured during prediction.")
21
-
22
-
23
- # an incomplete prediction problem indicates that the model successfully returns
24
- # predictions, but part of the input molecules are missing in the results
25
- def IncompletePredictionProblem() -> Problem:
26
- return Problem("incomplete_prediction_error", "The model couldn't process the molecule.")
27
-
28
-
29
- class Model(ABC):
30
- def __init__(self) -> None:
31
- super().__init__()
32
-
33
- @abstractmethod
34
- def _predict_mols(self, mols: List[Mol], **kwargs: Any) -> Iterable[dict]:
35
- pass
36
-
37
- @abstractmethod
38
- def _get_input_steps(
39
- self, input: Any, input_format: Optional[str], **kwargs: Any
40
- ) -> List[Step]:
41
- pass
42
-
43
- @abstractmethod
44
- def _get_preprocessing_steps(
45
- self, input: Any, input_format: Optional[str], **kwargs: Any
46
- ) -> List[Step]:
47
- pass
48
-
49
- @abstractmethod
50
- def _get_postprocessing_steps(self, output_format: Optional[str], **kwargs: Any) -> List[Step]:
51
- pass
52
-
53
- def predict(
54
- self,
55
- input: Any,
56
- input_format: Optional[str] = None,
57
- output_format: Optional[str] = None,
58
- **kwargs: Any,
59
- ) -> Any:
60
- input_steps = self._get_input_steps(input, input_format, **kwargs)
61
- preprocessing_steps = self._get_preprocessing_steps(input, input_format, **kwargs)
62
- postprocessing_steps = self._get_postprocessing_steps(output_format, **kwargs)
63
- output_step = postprocessing_steps[-1]
64
-
65
- assert isinstance(output_step, OutputStep), "The last step must be an OutputStep."
66
-
67
- steps = [
68
- *input_steps,
69
- *preprocessing_steps,
70
- PredictionStep(self, batch_size=self.batch_size, **kwargs),
71
- *postprocessing_steps,
72
- ]
73
-
74
- # build the pipeline from the list of steps
75
- pipeline = None
76
- for t in steps:
77
- pipeline = t(pipeline)
78
-
79
- # the last pipeline step holds the result
80
- return output_step.get_result()
81
-
82
- #
83
- # Properties
84
- #
85
- def _get_batch_size(self) -> int:
86
- return 1
87
-
88
- batch_size = property(fget=lambda self: self._get_batch_size())
89
-
90
- def _get_name(self) -> str:
91
- return snakecase(self.__class__.__name__)
92
-
93
- name = property(fget=lambda self: self._get_name())
94
-
95
- def _get_description(self) -> str:
96
- return ""
97
-
98
- description = property(fget=lambda self: self._get_description())
99
-
100
- def _get_job_parameters(self) -> List[JobParameter]:
101
- return []
102
-
103
- job_parameters = property(fget=lambda self: self._get_job_parameters())
12
+ __all__ = ["PredictionStep"]
104
13
 
105
14
 
106
15
  class PredictionStep(Step):
107
- def __init__(self, model: Model, batch_size: int, **kwargs: Any) -> None:
16
+ def __init__(self, predict_fn: Callable, task: Task, batch_size: int, **kwargs: Any) -> None:
108
17
  super().__init__()
109
- self.model = model
110
- self.batch_size = batch_size
111
- self.kwargs = kwargs
18
+ self._predict_fn = predict_fn
19
+ self._task = task
20
+ self._batch_size = batch_size
21
+ self._kwargs = kwargs
112
22
 
113
23
  def _run(self, source: Iterator[dict]) -> Iterator[dict]:
114
24
  # We need to process the molecules in batches, because most ML models perform
@@ -131,7 +41,7 @@ class PredictionStep(Step):
131
41
  if len(batch) > 0 or len(none_batch) > 0:
132
42
  yield batch, none_batch
133
43
 
134
- for batch, none_batch in _batch_and_filter(source, self.batch_size):
44
+ for batch, none_batch in _batch_and_filter(source, self._batch_size):
135
45
  # return the records where mols are None
136
46
  yield from none_batch
137
47
 
@@ -151,8 +61,8 @@ class PredictionStep(Step):
151
61
  if len(batch) > 0:
152
62
  predictions = list(
153
63
  call_with_mappings(
154
- self.model._predict_mols,
155
- {**self.kwargs, "mols": mols},
64
+ self._predict_fn,
65
+ {**self._kwargs, "mols": mols},
156
66
  )
157
67
  )
158
68
  else:
@@ -208,10 +118,25 @@ class PredictionStep(Step):
208
118
  record["mol_id"] in mol_id_set
209
119
  ), f"The mol_id {record['mol_id']} is not in the batch."
210
120
 
121
+ # depending on the task, we need to check atom_id or derivative_id
122
+ if self._task == "atom_property_prediction":
123
+ sub_id_property = "atom_id"
124
+ elif self._task == "derivative_property_prediction":
125
+ sub_id_property = "derivative_id"
126
+ else:
127
+ sub_id_property = None
128
+
211
129
  # create a mapping from mol_id to record (for quick access)
212
- mol_id_to_record = defaultdict(list)
130
+ mol_id_to_record: DefaultDict[int, List[dict]] = defaultdict(list)
213
131
  for record in predictions:
214
- mol_id_to_record[record["mol_id"]].append(record)
132
+ current_record_list = mol_id_to_record[record["mol_id"]]
133
+ current_record_list.append(record)
134
+ if len(current_record_list) > 1 and sub_id_property is None:
135
+ raise ValueError(
136
+ f"There are duplicate records for mol_id={record['mol_id']}, but the "
137
+ f"prediction task {self._task} requires unique mol_id values. The duplicates "
138
+ f"are: {current_record_list}."
139
+ )
215
140
 
216
141
  # add all records that are missing in the predictions
217
142
  for mol_id in temporary_mol_ids:
@@ -224,19 +149,63 @@ class PredictionStep(Step):
224
149
  }
225
150
  )
226
151
 
227
- # If the result has multiple entries per mol_id, check that atom_id or
228
- # derivative_id is present in multi-entry results.
229
- if len(predictions) > len(batch):
230
- for _, records in mol_id_to_record.items():
231
- if len(records) > 1:
232
- has_atom_id = all("atom_id" in record for record in records)
233
- has_derivative_id = all("derivative_id" in record for record in records)
234
- assert has_atom_id or has_derivative_id, (
235
- "The result contains multiple entries per molecule, but does "
236
- "not contain atom_id or derivative_id."
152
+ if sub_id_property is not None:
153
+ # task must be either atom_property_prediction or derivative_property_prediction
154
+ # -> check consistency of sub_id_property
155
+ for mol_id, records in mol_id_to_record.items():
156
+ sub_ids: Set[int] = set()
157
+
158
+ for record in records:
159
+ sub_id = record.get(sub_id_property)
160
+ if sub_id is not None:
161
+ # check that sub_id is an integer
162
+ if not isinstance(sub_id, int):
163
+ raise ValueError(
164
+ f"The {sub_id_property} must be an integer, but got {sub_id}. "
165
+ f"Record: {record}"
166
+ )
167
+
168
+ sub_ids.add(sub_id)
169
+
170
+ if (
171
+ len(records) == 1
172
+ and "problems" in records[0]
173
+ and len(records[0]["problems"]) > 0
174
+ ):
175
+ # this record was not predicted, so we skip it
176
+ continue
177
+ elif len(sub_ids) == 0:
178
+ # no record has a sub id, we assign them (sequentially)
179
+ for i, record in enumerate(records):
180
+ record[sub_id_property] = i
181
+ continue
182
+ elif len(sub_ids) < len(records):
183
+ # None is not in sub_ids, but the number of unique sub ids is less than
184
+ # the number of records.
185
+ # -> there must be duplicates
186
+ sub_id_list = [record.get(sub_id_property) for record in records]
187
+ raise ValueError(
188
+ f"The result with mol_id={mol_id} contains multiple entries per "
189
+ f"molecule, but the sequence of {sub_id_property} is not unique. "
190
+ f"Found: {sub_id_list}."
237
191
  )
238
-
239
- # TODO: check range and completeness of atom ids and derivative ids
192
+ else:
193
+ min_sub_id = min(sub_ids)
194
+ max_sub_id = max(sub_ids)
195
+
196
+ if min_sub_id != 0:
197
+ raise ValueError(
198
+ f"The sequence of {sub_id_property} does not start at 0 for "
199
+ f"mol_id={mol_id}. Instead, the minimum {sub_id_property} was "
200
+ f"{min_sub_id}."
201
+ )
202
+ elif max_sub_id - min_sub_id + 1 != len(sub_ids):
203
+ # there are gaps in the sequence of sub ids
204
+ raise ValueError(
205
+ f"The result with mol_id={mol_id} contains multiple entries per "
206
+ f"molecule, but the sequence of {sub_id_property} has gaps. "
207
+ f"Found: {sub_ids}."
208
+ )
240
209
 
241
210
  for key, records in mol_id_to_record.items():
242
211
  for record in records:
@@ -2,6 +2,8 @@ from typing import Iterable, NamedTuple
2
2
 
3
3
  __all__ = [
4
4
  "Problem",
5
+ "UnknownPredictionProblem",
6
+ "IncompletePredictionProblem",
5
7
  "InvalidSmiles",
6
8
  "UnknownProblem",
7
9
  "InvalidWeightProblem",
@@ -14,6 +16,18 @@ class Problem(NamedTuple):
14
16
  message: str
15
17
 
16
18
 
19
+ # an unknown prediction problem indicates that the model raised an exception during
20
+ # prediction
21
+ def UnknownPredictionProblem() -> Problem:
22
+ return Problem("unknown_prediction_error", "An unknown error occured during prediction.")
23
+
24
+
25
+ # an incomplete prediction problem indicates that the model successfully returns
26
+ # predictions, but part of the input molecules are missing in the results
27
+ def IncompletePredictionProblem() -> Problem:
28
+ return Problem("incomplete_prediction_error", "The model couldn't process the molecule.")
29
+
30
+
17
31
  def InvalidSmiles() -> Problem:
18
32
  return Problem(type="invalid_smiles", message="Invalid SMILES string")
19
33
 
@@ -1,4 +1,4 @@
1
- from nerdd_module import SimpleModel
1
+ from nerdd_module import Model
2
2
  from nerdd_module.preprocessing import Sanitize
3
3
 
4
4
  __all__ = ["AtomicMassModel"]
@@ -7,7 +7,7 @@ __all__ = ["AtomicMassModel"]
7
7
  allowed_versions = ["mol_ids", "mols", "iterator", "error"]
8
8
 
9
9
 
10
- class AtomicMassModel(SimpleModel):
10
+ class AtomicMassModel(Model):
11
11
  def __init__(self, preprocessing_steps=[Sanitize()], version="mol_ids", **kwargs):
12
12
  assert (
13
13
  version in allowed_versions
@@ -1,6 +1,6 @@
1
1
  from rdkit.Chem.rdMolDescriptors import CalcExactMolWt
2
2
 
3
- from nerdd_module import SimpleModel
3
+ from nerdd_module import Model
4
4
  from nerdd_module.preprocessing import Sanitize
5
5
 
6
6
  __all__ = ["MolWeightModel"]
@@ -8,7 +8,7 @@ __all__ = ["MolWeightModel"]
8
8
  allowed_versions = ["order_based", "mol_ids", "mols", "iterator", "error"]
9
9
 
10
10
 
11
- class MolWeightModel(SimpleModel):
11
+ class MolWeightModel(Model):
12
12
  def __init__(self, preprocessing_steps=[Sanitize()], version="order_based", **kwargs):
13
13
  assert (
14
14
  version in allowed_versions
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: nerdd-module
3
- Version: 0.3.37
3
+ Version: 0.3.39
4
4
  Summary: Base package to create NERDD modules
5
5
  Author-email: Steffen Hirte <steffen.hirte@univie.ac.at>
6
6
  Maintainer-email: Steffen Hirte <steffen.hirte@univie.ac.at>
@@ -49,8 +49,8 @@ nerdd_module/model/assign_name_step.py
49
49
  nerdd_module/model/convert_representations_step.py
50
50
  nerdd_module/model/enforce_schema_step.py
51
51
  nerdd_module/model/model.py
52
+ nerdd_module/model/prediction_step.py
52
53
  nerdd_module/model/read_input_step.py
53
- nerdd_module/model/simple_model.py
54
54
  nerdd_module/model/write_output_step.py
55
55
  nerdd_module/output/__init__.py
56
56
  nerdd_module/output/csv_writer.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "nerdd-module"
7
- version = "0.3.37"
7
+ version = "0.3.39"
8
8
  description = "Base package to create NERDD modules"
9
9
  readme = "README.md"
10
10
  license = "BSD-3-Clause"
File without changes
File without changes
File without changes