nerdd-module 0.1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. nerdd-module-0.1.6/LICENSE +21 -0
  2. nerdd-module-0.1.6/PKG-INFO +87 -0
  3. nerdd-module-0.1.6/README.md +61 -0
  4. nerdd-module-0.1.6/nerdd_module/__init__.py +8 -0
  5. nerdd-module-0.1.6/nerdd_module/abstract_model.py +274 -0
  6. nerdd-module-0.1.6/nerdd_module/cli.py +142 -0
  7. nerdd-module-0.1.6/nerdd_module/config/__init__.py +6 -0
  8. nerdd-module-0.1.6/nerdd_module/config/auto_configuration.py +48 -0
  9. nerdd-module-0.1.6/nerdd_module/config/configuration.py +52 -0
  10. nerdd-module-0.1.6/nerdd_module/config/default_configuration.py +15 -0
  11. nerdd-module-0.1.6/nerdd_module/config/dict_configuration.py +13 -0
  12. nerdd-module-0.1.6/nerdd_module/config/merged_configuration.py +16 -0
  13. nerdd-module-0.1.6/nerdd_module/config/yaml_configuration.py +45 -0
  14. nerdd-module-0.1.6/nerdd_module/io/__init__.py +17 -0
  15. nerdd-module-0.1.6/nerdd_module/io/csv_writer.py +40 -0
  16. nerdd-module-0.1.6/nerdd_module/io/elementary_reader.py +30 -0
  17. nerdd-module-0.1.6/nerdd_module/io/file_reader.py +30 -0
  18. nerdd-module-0.1.6/nerdd_module/io/guess_and_read.py +58 -0
  19. nerdd-module-0.1.6/nerdd_module/io/guessing_reader.py +52 -0
  20. nerdd-module-0.1.6/nerdd_module/io/inchi_file_reader.py +27 -0
  21. nerdd-module-0.1.6/nerdd_module/io/inchi_reader.py +35 -0
  22. nerdd-module-0.1.6/nerdd_module/io/list_reader.py +22 -0
  23. nerdd-module-0.1.6/nerdd_module/io/mol_block_reader.py +37 -0
  24. nerdd-module-0.1.6/nerdd_module/io/rdkit_mol_reader.py +21 -0
  25. nerdd-module-0.1.6/nerdd_module/io/reader.py +31 -0
  26. nerdd-module-0.1.6/nerdd_module/io/reader_registry.py +43 -0
  27. nerdd-module-0.1.6/nerdd_module/io/sdf_file_reader.py +42 -0
  28. nerdd-module-0.1.6/nerdd_module/io/sdf_writer.py +35 -0
  29. nerdd-module-0.1.6/nerdd_module/io/smiles_file_reader.py +28 -0
  30. nerdd-module-0.1.6/nerdd_module/io/smiles_reader.py +40 -0
  31. nerdd-module-0.1.6/nerdd_module/io/writer.py +36 -0
  32. nerdd-module-0.1.6/nerdd_module/io/writer_registry.py +40 -0
  33. nerdd-module-0.1.6/nerdd_module/preprocessing/__init__.py +8 -0
  34. nerdd-module-0.1.6/nerdd_module/preprocessing/check_valid_smiles.py +27 -0
  35. nerdd-module-0.1.6/nerdd_module/preprocessing/chembl_structure_pipeline.py +112 -0
  36. nerdd-module-0.1.6/nerdd_module/preprocessing/empty_pipeline.py +8 -0
  37. nerdd-module-0.1.6/nerdd_module/preprocessing/filter_by_element.py +29 -0
  38. nerdd-module-0.1.6/nerdd_module/preprocessing/filter_by_weight.py +29 -0
  39. nerdd-module-0.1.6/nerdd_module/preprocessing/pipeline.py +52 -0
  40. nerdd-module-0.1.6/nerdd_module/preprocessing/registry.py +20 -0
  41. nerdd-module-0.1.6/nerdd_module/preprocessing/remove_stereochemistry.py +21 -0
  42. nerdd-module-0.1.6/nerdd_module/preprocessing/step.py +24 -0
  43. nerdd-module-0.1.6/nerdd_module/version.py +10 -0
  44. nerdd-module-0.1.6/nerdd_module.egg-info/PKG-INFO +87 -0
  45. nerdd-module-0.1.6/nerdd_module.egg-info/SOURCES.txt +64 -0
  46. nerdd-module-0.1.6/nerdd_module.egg-info/dependency_links.txt +1 -0
  47. nerdd-module-0.1.6/nerdd_module.egg-info/requires.txt +18 -0
  48. nerdd-module-0.1.6/nerdd_module.egg-info/top_level.txt +2 -0
  49. nerdd-module-0.1.6/setup.cfg +4 -0
  50. nerdd-module-0.1.6/setup.py +51 -0
  51. nerdd-module-0.1.6/tests/__init__.py +0 -0
  52. nerdd-module-0.1.6/tests/conftest.py +7 -0
  53. nerdd-module-0.1.6/tests/models/AtomicMassModel.py +29 -0
  54. nerdd-module-0.1.6/tests/models/MolWeightModel.py +28 -0
  55. nerdd-module-0.1.6/tests/models/MolWeightModelWithExplicitMolIds.py +29 -0
  56. nerdd-module-0.1.6/tests/models/MolWeightModelWithExplicitMols.py +29 -0
  57. nerdd-module-0.1.6/tests/models/__init__.py +4 -0
  58. nerdd-module-0.1.6/tests/steps/__init__.py +4 -0
  59. nerdd-module-0.1.6/tests/steps/checks.py +37 -0
  60. nerdd-module-0.1.6/tests/steps/molecules.py +53 -0
  61. nerdd-module-0.1.6/tests/steps/predictors.py +52 -0
  62. nerdd-module-0.1.6/tests/steps/preprocessing.py +9 -0
  63. nerdd-module-0.1.6/tests/test_atom_property_prediction.py +66 -0
  64. nerdd-module-0.1.6/tests/test_molecule_property_prediction.py +68 -0
  65. nerdd-module-0.1.6/tests/test_preprocessing.py +12 -0
  66. nerdd-module-0.1.6/tests/test_reading_formats.py +134 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Molecular Informatics Vienna
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,87 @@
1
+ Metadata-Version: 2.1
2
+ Name: nerdd-module
3
+ Version: 0.1.6
4
+ Summary: Base package to create NERDD modules
5
+ Home-page: https://github.com/molinfo-vienna/nerdd-module.git
6
+ Maintainer: Steffen Hirte
7
+ Maintainer-email: steffen.hirte@univie.ac.at
8
+ Description-Content-Type: text/markdown
9
+ License-File: LICENSE
10
+ Requires-Dist: pandas>=1.2.1
11
+ Requires-Dist: pyyaml>=6.0
12
+ Requires-Dist: filetype~=1.2.0
13
+ Requires-Dist: rich-click>=1.7.1
14
+ Requires-Dist: stringcase>=1.2.0
15
+ Requires-Dist: chembl_structure_pipeline>=1.0.0
16
+ Provides-Extra: dev
17
+ Provides-Extra: test
18
+ Requires-Dist: pytest; extra == "test"
19
+ Requires-Dist: pytest-cov; extra == "test"
20
+ Requires-Dist: pytest-asyncio; extra == "test"
21
+ Requires-Dist: pytest-bdd; extra == "test"
22
+ Requires-Dist: pytest-mock; extra == "test"
23
+ Requires-Dist: pytest-watch; extra == "test"
24
+ Requires-Dist: hypothesis; extra == "test"
25
+ Requires-Dist: hypothesis-rdkit; extra == "test"
26
+
27
+ # Nerdd Module
28
+
29
+ This package provides the basis to implement molecular prediction modules in the
30
+ NERDD ecosystem.
31
+
32
+ ## Installation
33
+
34
+ ```pip install nerdd-module```
35
+
36
+
37
+ ## Implement your own module
38
+
39
+ A new module is created by inheriting from the ```AbstractModel``` class. A
40
+ preprocessing pipeline can be configured by calling the constructor of the superclass.
41
+ The actual prediction procedure is implemented in ```_predict_mols```:
42
+
43
+ ```python
44
+ import pandas as pd
45
+ from typing import List
46
+ from rdkit.Chem import Mol
47
+ from nerdd_module import AbstractModel
48
+
49
+ class MyModel(AbstractModel):
50
+ def __init__(self):
51
+ super().__init__(
52
+ preprocessing_pipeline="chembl_structure_pipeline",
53
+ )
54
+
55
+ def _predict_mols(self, mols: List[Mol], custom_param: int = 5) -> pd.DataFrame:
56
+ # implement prediction logic and return a dataframe with new columns
57
+ # containing values per input molecule
58
+ return pd.DataFrame(dict(predictions=[custom_param]*len(mols)))
59
+ ```
60
+
61
+ For custom preprocessing, specify ```preprocessing_pipeline="custom"``` when calling
62
+ the constructor of the superclass and override the method ```_preprocess_single_mol```:
63
+
64
+ ```python
65
+ class MyModel(AbstractModel):
66
+ def __init__(self):
67
+ # important:
68
+ super().__init__(preprocessing_pipeline="custom")
69
+
70
+ def _preprocess_single_mol(self, mol: Mol) -> Tuple[Mol, List[str]]:
71
+ # implement custom preprocessing logic here
72
+ # return preprocessed molecule and a list of error messages
73
+ return preprocessed_mol, errors
74
+ # ...
75
+ ```
76
+
77
+
78
+ ## Contribute
79
+
80
+ 1. Fork and clone the code
81
+ 2. Install test dependencies with ```pip install -e .[test]```
82
+ 3. Run tests via ```pytest``` or ```pytest-watch``` (short: ```ptw```)
83
+
84
+
85
+ ## Contributors
86
+
87
+ * Steffen Hirte
@@ -0,0 +1,61 @@
1
+ # Nerdd Module
2
+
3
+ This package provides the basis to implement molecular prediction modules in the
4
+ NERDD ecosystem.
5
+
6
+ ## Installation
7
+
8
+ ```pip install nerdd-module```
9
+
10
+
11
+ ## Implement your own module
12
+
13
+ A new module is created by inheriting from the ```AbstractModel``` class. A
14
+ preprocessing pipeline can be configured by calling the constructor of the superclass.
15
+ The actual prediction procedure is implemented in ```_predict_mols```:
16
+
17
+ ```python
18
+ import pandas as pd
19
+ from typing import List
20
+ from rdkit.Chem import Mol
21
+ from nerdd_module import AbstractModel
22
+
23
+ class MyModel(AbstractModel):
24
+ def __init__(self):
25
+ super().__init__(
26
+ preprocessing_pipeline="chembl_structure_pipeline",
27
+ )
28
+
29
+ def _predict_mols(self, mols: List[Mol], custom_param: int = 5) -> pd.DataFrame:
30
+ # implement prediction logic and return a dataframe with new columns
31
+ # containing values per input molecule
32
+ return pd.DataFrame(dict(predictions=[custom_param]*len(mols)))
33
+ ```
34
+
35
+ For custom preprocessing, specify ```preprocessing_pipeline="custom"``` when calling
36
+ the constructor of the superclass and override the method ```_preprocess_single_mol```:
37
+
38
+ ```python
39
+ class MyModel(AbstractModel):
40
+ def __init__(self):
41
+ # important:
42
+ super().__init__(preprocessing_pipeline="custom")
43
+
44
+ def _preprocess_single_mol(self, mol: Mol) -> Tuple[Mol, List[str]]:
45
+ # implement custom preprocessing logic here
46
+ # return preprocessed molecule and a list of error messages
47
+ return preprocessed_mol, errors
48
+ # ...
49
+ ```
50
+
51
+
52
+ ## Contribute
53
+
54
+ 1. Fork and clone the code
55
+ 2. Install test dependencies with ```pip install -e .[test]```
56
+ 3. Run tests via ```pytest``` or ```pytest-watch``` (short: ```ptw```)
57
+
58
+
59
+ ## Contributors
60
+
61
+ * Steffen Hirte
@@ -0,0 +1,8 @@
1
+ import pkg_resources
2
+
3
+ from .abstract_model import *
4
+ from .config import *
5
+ from .version import *
6
+
7
# Load every entry point registered in the "nerdd-module.plugins" group.
# The loaded objects are discarded, so loading is done purely for its import
# side effects (presumably plugins register themselves on import -- confirm
# against the plugin packages).
for entry_point in pkg_resources.iter_entry_points("nerdd-module.plugins"):
    entry_point.load()
@@ -0,0 +1,274 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
3
+
4
+ import pandas as pd
5
+ from rdkit.Chem import Mol, MolToSmiles
6
+
7
+ from .config import AutoConfiguration, Configuration
8
+ from .io import MoleculeEntry, guess_and_read
9
+ from .preprocessing import Pipeline, Step, registry
10
+
11
+ __all__ = ["AbstractModel"]
12
+
13
+
14
class CustomPreprocessingStep(Step):
    """Preprocessing step that delegates all work to a user-supplied callable."""

    def __init__(self, fn: Callable[[Mol], Tuple[Mol, List[str]]]):
        """Store the callable that is invoked for every molecule.

        ``fn`` receives a molecule and returns the processed molecule together
        with a list of error messages.
        """
        super().__init__()
        self.fn = fn

    def _run(self, mol: Mol) -> Tuple[Mol, List[str]]:
        """Apply the wrapped callable to ``mol`` and hand back its result."""
        outcome = self.fn(mol)
        return outcome
21
+
22
+
23
class AbstractModel(ABC):
    """Base class for prediction models.

    Subclasses implement ``_predict_mols`` and, when the preprocessing pipeline
    is ``None`` or ``"custom"``, also ``_preprocess_single_mol``.
    """

    def __init__(
        self,
        preprocessing_pipeline: Union[str, Pipeline, Iterable[Step], None],
        num_processes: int = 1,
    ):
        """Configure the preprocessing pipeline of the model.

        ``preprocessing_pipeline`` may be

        * ``None`` or ``"custom"``: ``_preprocess_single_mol`` is used,
        * a ``Pipeline`` instance: used as is,
        * the name of a pipeline in the preprocessing registry,
        * an iterable of ``Step`` instances: wrapped in a new ``Pipeline``.

        Raises ``ValueError`` for anything else.
        """
        #
        # preprocessing pipeline
        #
        if preprocessing_pipeline is None or preprocessing_pipeline == "custom":
            self.preprocessing_pipeline = Pipeline(
                steps=[CustomPreprocessingStep(self._preprocess_single_mol)]
            )
        elif isinstance(preprocessing_pipeline, Pipeline):
            self.preprocessing_pipeline = preprocessing_pipeline
        elif isinstance(preprocessing_pipeline, str):
            if preprocessing_pipeline in registry:
                self.preprocessing_pipeline = registry[preprocessing_pipeline]
            else:
                # build the list of choices separately: placing the message
                # literal directly in front of ", ".join(...) would merge both
                # literals into the join separator
                choices = ", ".join(list(registry.keys()) + ["custom"])
                raise ValueError(
                    "Invalid preprocessing pipeline. Choose one of the following: "
                    + choices
                )
        elif isinstance(preprocessing_pipeline, Iterable):
            # materialize the steps first so that one-shot iterables (e.g.
            # generators) are not exhausted by the type check below
            steps = list(preprocessing_pipeline)
            if not all(isinstance(step, Step) for step in steps):
                raise ValueError(
                    f"Invalid preprocessing pipeline {preprocessing_pipeline}."
                )
            self.preprocessing_pipeline = Pipeline(steps=steps)
        else:
            raise ValueError(
                f"Invalid preprocessing pipeline {preprocessing_pipeline}."
            )

        #
        # reading molecules
        #

        # add methods for all supported formats
        # TODO

        #
        # other parameters
        #
        self.num_processes = num_processes

    @staticmethod
    def _join_unique(messages: Iterable[str]) -> str:
        """Join messages with ", ", dropping duplicates and empty strings.

        The order of first occurrence is kept so that the result is
        deterministic (joining a ``set`` is not).
        """
        return ", ".join(dict.fromkeys(m for m in messages if m))

    def _preprocess_single_mol(self, mol: Mol) -> Tuple[Mol, List[str]]:
        """Preprocess one molecule and return it with a list of error messages.

        Only called when the preprocessing pipeline is ``None`` or ``"custom"``;
        in that case subclasses have to override this method.
        """
        raise NotImplementedError()

    @abstractmethod
    def _predict_mols(self, mols: List[Mol], **kwargs) -> pd.DataFrame:
        """Predict properties for the given (preprocessed) molecules.

        The returned dataframe is matched to the input molecules via a
        ``mol_id`` column, a ``mol`` column or, if neither exists, by position
        (see ``_predict_entries``).
        """
        pass

    def _predict_entries(
        self,
        inputs: Iterable[MoleculeEntry],
        **kwargs,
    ) -> pd.DataFrame:
        """Preprocess the given entries, predict and merge everything.

        ``inputs`` provides, per entry, the fields ``input``, ``input_type``,
        ``source``, ``mol`` and ``load_errors``. ``kwargs`` are forwarded to
        ``_predict_mols``.

        The result starts with the columns ``mol_id``, ``input``,
        ``input_type``, ``source``, ``name``, ``input_mol``,
        ``preprocessed_mol`` and ``errors``, followed by all columns returned
        by ``_predict_mols``.
        """
        #
        # LOAD MOLECULES
        #
        df_load = pd.DataFrame(
            inputs,
            columns=["input", "input_type", "source", "mol", "load_errors"],
        )
        df_load["mol_id"] = range(len(df_load))

        #
        # PREPROCESS ALL MOLECULES
        #
        df_preprocess = pd.DataFrame(
            [self.preprocessing_pipeline.run(mol) for mol in df_load.mol],
            columns=["preprocessed_mol", "preprocessing_errors"],
        )

        # necessary for models that create multiple (or zero) entries per molecule
        df_preprocess["mol_id"] = range(len(df_preprocess))

        # add raw molecules to dataframe
        df_preprocess["input_mol"] = df_load.mol

        # add name to dataframe
        df_preprocess["name"] = [
            (mol.GetProp("_Name") if mol is not None and mol.HasProp("_Name") else "")
            for mol in df_preprocess.input_mol
        ]

        #
        # PREPARE PREDICTION OF MOLECULES
        #

        # each molecule gets its unique id (0, 1, ..., n) as its name
        for mol_id, mol in zip(df_preprocess.mol_id, df_preprocess.preprocessed_mol):
            if mol is not None:
                mol.SetProp("_Name", str(mol_id))

        # do the prediction on molecules that are not None
        df_valid_subset = df_preprocess[df_preprocess.preprocessed_mol.notnull()]

        #
        # PREDICTION
        #
        df_predictions = self._predict_mols(
            df_valid_subset.preprocessed_mol.tolist(), **kwargs
        )

        #
        # POST PROCESSING AND ERROR HANDLING
        #

        # make sure that reserved column names do not appear in the output dataframe
        reserved_column_names = ["input", "name", "input_mol"]
        assert (
            set(df_predictions.columns).intersection(reserved_column_names) == set()
        ), f"Do not use reserved column names {', '.join(reserved_column_names)}!"

        # during prediction, molecules might have been removed / reordered
        # there are three ways to connect the predictions to the original molecules:
        # 1. df_prediction contains a column "mol_id" that contains the molecule ids
        # 2. df_prediction contains a column "mol" that contains the molecules, which
        #    have the id as their name so that we can match them to the original
        # 3. df_prediction has the same length as the number of valid molecules
        #    (and we assume that the order of the molecules is the same)
        if "mol_id" in df_predictions.columns:
            # check that mol_id contains only valid ids
            assert set(df_predictions.mol_id).issubset(
                set(df_valid_subset.mol_id)
            ), "The mol_id column must only contain valid ids!"
            # use mol_id as index
            df_predictions.set_index("mol_id", drop=True, inplace=True)
        elif "mol" in df_predictions.columns:
            # check that molecule names contain only valid ids
            names = df_predictions.mol.apply(lambda mol: int(mol.GetProp("_Name")))
            assert set(names).issubset(
                set(df_preprocess.mol_id)
            ), "The molecule names must only contain valid ids!"

            # use mol_id as index
            df_predictions.set_index(
                names,
                inplace=True,
            )
            df_predictions.drop(columns="mol", inplace=True)
        else:
            assert len(df_predictions) == len(df_valid_subset), (
                "The number of predicted molecules must be equal to the number of "
                "valid input molecules."
            )
            # use index from input series (type cast if series was empty)
            df_predictions.set_index(
                df_valid_subset.index.astype("int64"), inplace=True
            )

        # add column that indicates whether a molecule was missing
        missing_mol_ids = set(df_preprocess.mol_id).difference(df_predictions.index)
        df_preprocess["missing"] = df_preprocess.mol_id.isin(missing_mol_ids)

        # merge the preprocessed molecules with the predictions
        df_result = df_preprocess.merge(
            df_predictions, left_on="mol_id", right_index=True, how="left"
        )

        # if the result has multiple entries per mol_id, check that atom_id or
        # derivative_id is present
        if len(df_result) > df_result.mol_id.nunique():
            assert (
                "atom_id" in df_result.columns or "derivative_id" in df_result.columns
            ), (
                "The result contains multiple entries per molecule, but does not "
                "contain atom_id or derivative_id."
            )

        # merge errors from preprocessing and prediction
        if "prediction_errors" in df_result.columns:
            df_result["errors"] = (
                df_result.preprocessing_errors + df_result.prediction_errors
            )
            df_result.drop(columns=["prediction_errors"], inplace=True)
        else:
            df_result["errors"] = df_result.preprocessing_errors
        # molecules without any prediction get the error code "!1"
        df_result["errors"] = df_result.errors + df_result.missing.map(
            lambda x: ["!1"] if x else []
        )
        df_result.drop(columns=["missing", "preprocessing_errors"], inplace=True)

        # convert errors to string
        df_result["errors"] = df_result.errors.map(self._join_unique)

        # delete mol column (not needed anymore)
        df_load.drop(columns=["mol"], inplace=True)

        # merge load and prediction
        df_result = df_result.merge(df_load, on="mol_id", how="left")

        # merge errors from loading and prediction (empty strings are skipped so
        # that no dangling separators end up in the result)
        df_result["errors"] = [
            self._join_unique([*load_errors, prediction_errors])
            for load_errors, prediction_errors in zip(
                df_result.load_errors, df_result.errors
            )
        ]

        df_result.drop(columns=["load_errors"], inplace=True)

        # reorder columns
        mandatory_columns = [
            "mol_id",
            "input",
            "input_type",
            "source",
            "name",
            "input_mol",
            "preprocessed_mol",
            "errors",
        ]
        remaining_columns = [c for c in df_result.columns if c not in mandatory_columns]
        df_result = df_result[mandatory_columns + remaining_columns]

        return df_result

    def predict(
        self,
        inputs: Union[Iterable[str], Iterable[Mol], str, Mol],
        input_type=None,
        **kwargs,
    ):
        """Read molecules from ``inputs`` and return the prediction dataframe.

        ``kwargs`` are forwarded to ``_predict_mols``.

        NOTE(review): ``input_type`` is accepted but currently not used; the
        format is always guessed by ``guess_and_read``.
        """
        entries = guess_and_read(inputs)

        return self._predict_entries(entries, **kwargs)

    def get_config(self) -> Configuration:
        """Return the automatically assembled configuration of this model."""
        return AutoConfiguration(self)
@@ -0,0 +1,142 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+
5
+ import rich_click as click
6
+ from decorator import decorator
7
+
8
+ from nerdd_module.io import WriterRegistry
9
+
10
+ __all__ = ["auto_cli"]
11
+
12
# Template for the help text of the generated command. The placeholders
# {description} and {format_list} are filled in by auto_cli via str.format.
input_description = """{description}

INPUT molecules are provided as file paths or strings. The following formats are
supported:

{format_list}

Note that input formats shouldn't be mixed.
"""
21
+
22
+
23
def infer_click_type(param):
    """Return the click parameter type for a job parameter description.

    If ``param`` contains ``"choices"``, a ``click.Choice`` over the ``"value"``
    of each choice is returned. Otherwise the entry ``"type"`` is mapped to the
    Python type of the same name (``float``, ``int``, ``str`` or ``bool``).

    Raises ``KeyError`` if the type is missing or not supported.
    """
    if "choices" in param:
        choices = [c["value"] for c in param["choices"]]
        return click.Choice(choices)

    type_map = {
        "float": float,
        "int": int,
        "str": str,
        "bool": bool,
    }

    type_name = param.get("type")
    if type_name not in type_map:
        # same exception type as a plain dict lookup, but with a message that
        # tells which parameter is misconfigured
        raise KeyError(
            f"Unsupported type {type_name!r} for parameter "
            f"{param.get('name')!r}. Supported types: {', '.join(type_map)}"
        )
    return type_map[type_name]
36
+
37
+
38
@decorator
def auto_cli(f, *args, **kwargs):
    """Build a command line interface for the model returned by ``f`` and run it.

    ``f`` is called without arguments and has to return a model. The model's
    configuration provides the description, an optional example and the job
    parameters, each of which becomes a command line option. The command also
    accepts ``--output``, ``--format`` and ``--log-level``.
    """
    # infer the command name
    command_name = os.path.basename(sys.argv[0])

    # get the model
    model = f()

    config = model.get_config().get_dict()

    # compose cli description
    description = config.get("description", "")

    format_list = "\n".join([f"* {fmt}" for fmt in ["smiles", "sdf", "inchi"]])

    help_text = input_description.format(
        description=description, format_list=format_list
    )

    # compose footer with examples
    examples = []
    if "example_smiles" in config:
        examples.append(config["example_smiles"])

    if examples:
        footer = "Examples:\n" + "".join(
            f"* {command_name} {example}\n" for example in examples
        )
    else:
        footer = ""

    # show_default=True: default values are shown in the help text
    # show_metavars_column=False: the column types are not in a separate column
    # append_metavars_help=True: the column types are shown below the help text
    @click.command(context_settings={"show_default": True}, help=help_text)
    @click.rich_config(
        help_config=click.RichHelpConfiguration(
            use_markdown=True,
            show_metavars_column=False,
            append_metavars_help=True,
            footer_text=footer,
        )
    )
    @click.argument("input", type=click.Path(), nargs=-1, required=True)
    def main(
        input,
        format: str,
        output: str,
        log_level: str,
        **kwargs,
    ):
        logging.basicConfig(level=log_level.upper())

        df_result = model.predict(input, **kwargs)

        # look up the writer before any file is opened
        writer_registry = WriterRegistry()
        assert format in writer_registry.supported_formats
        writer = writer_registry.get_writer(format)

        entries = (tup._asdict() for tup in df_result.itertuples(index=False))

        # write results
        if output.lower() == "stdout":
            writer.write(sys.stdout, entries)
        else:
            # the context manager makes sure the file is flushed and closed
            with click.open_file(output, "wb") as output_handle:
                writer.write(output_handle, entries)

    #
    # Add job parameters
    #
    # modules without job parameters simply get no additional options
    for param in config.get("job_parameters", []):
        main = click.option(
            f"--{param['name']}",
            default=param.get("default", None),
            type=infer_click_type(param),
            help=param.get("help_text", None),
        )(main)

    #
    # Add other options
    #
    main = click.option(
        "--output",
        default="stdout",
        type=click.Path(),
        help="The output file. If 'stdout' is specified, the output is written to stdout.",
    )(main)

    main = click.option(
        "--format",
        default="csv",
        type=click.Choice(["csv", "sdf"], case_sensitive=False),
        help="The output format.",
    )(main)

    main = click.option(
        "--log-level",
        default="warning",
        type=click.Choice(
            ["debug", "info", "warning", "error", "critical"], case_sensitive=False
        ),
        help="The logging level.",
    )(main)

    return main()
@@ -0,0 +1,6 @@
1
+ from .auto_configuration import *
2
+ from .configuration import *
3
+ from .default_configuration import *
4
+ from .dict_configuration import *
5
+ from .merged_configuration import *
6
+ from .yaml_configuration import *
@@ -0,0 +1,48 @@
1
+ import os
2
+ import sys
3
+
4
+ from .configuration import Configuration
5
+ from .default_configuration import DefaultConfiguration
6
+ from .dict_configuration import DictConfiguration
7
+ from .merged_configuration import MergedConfiguration
8
+ from .yaml_configuration import YamlConfiguration
9
+
10
+ __all__ = ["AutoConfiguration"]
11
+
12
+
13
class AutoConfiguration(Configuration):
    """Configuration assembled from all sources available for a module.

    The sources are merged in this order: the default configuration, a
    ``nerdd.yml`` file found next to or above the file defining the module's
    class, and the dictionary returned by the module's ``_get_config`` method
    (if the module has one).
    """

    def __init__(self, nerdd_module):
        super().__init__()

        # 1. module has a default configuration (containing default values)
        sources = [DefaultConfiguration(nerdd_module)]

        # 2. module can be configured via a yaml file (nerdd.yml); the search
        #    starts at the file where the class of the module is defined
        module_name = nerdd_module.__class__.__module__
        yaml_path = self._find_yaml(sys.modules[module_name].__file__ or "")
        if yaml_path is not None:
            sources.append(YamlConfiguration(yaml_path))

        # 3. module can be configured via the method _get_config in the module
        if hasattr(nerdd_module, "_get_config"):
            sources.append(DictConfiguration(nerdd_module._get_config()))

        self.delegate = MergedConfiguration(*sources)

    @staticmethod
    def _find_yaml(start):
        """Walk up from ``start`` and return the first nerdd.yml path or None."""
        current = start
        while True:
            candidate = os.path.join(current, "nerdd.yml")
            if os.path.isfile(candidate):
                return candidate
            parent = os.path.dirname(current)
            if parent == current:
                # dirname no longer changes the path: top of the tree reached
                return None
            current = parent

    def _get_dict(self):
        """Return the dictionary of the merged configuration."""
        return self.delegate._get_dict()