nerdd-module 0.2.6__tar.gz → 0.3.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. nerdd_module-0.3.6/PKG-INFO +105 -0
  2. nerdd_module-0.3.6/README.md +18 -0
  3. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/__init__.py +2 -4
  4. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/cli.py +73 -59
  5. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/config/__init__.py +1 -1
  6. nerdd_module-0.3.6/nerdd_module/config/configuration.py +90 -0
  7. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/config/default_configuration.py +11 -13
  8. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/config/dict_configuration.py +4 -5
  9. nerdd_module-0.3.6/nerdd_module/config/merged_configuration.py +44 -0
  10. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/config/package_configuration.py +11 -7
  11. nerdd_module-0.3.6/nerdd_module/config/search_yaml_configuration.py +37 -0
  12. nerdd_module-0.3.6/nerdd_module/config/yaml_configuration.py +90 -0
  13. nerdd_module-0.3.6/nerdd_module/converters/__init__.py +2 -0
  14. nerdd_module-0.3.6/nerdd_module/converters/converter.py +62 -0
  15. nerdd_module-0.3.6/nerdd_module/converters/identity_converter.py +8 -0
  16. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/__init__.py +1 -1
  17. nerdd_module-0.3.6/nerdd_module/input/depth_first_explorer.py +129 -0
  18. nerdd_module-0.3.6/nerdd_module/input/explorer.py +16 -0
  19. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/file_reader.py +8 -9
  20. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/gzip_reader.py +4 -6
  21. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/inchi_reader.py +5 -7
  22. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/list_reader.py +4 -6
  23. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/mol_reader.py +4 -6
  24. nerdd_module-0.3.6/nerdd_module/input/reader.py +58 -0
  25. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/sdf_reader.py +5 -8
  26. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/smiles_reader.py +9 -8
  27. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/string_reader.py +4 -6
  28. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/tar_reader.py +4 -6
  29. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/input/zip_reader.py +4 -6
  30. nerdd_module-0.3.6/nerdd_module/model/__init__.py +7 -0
  31. nerdd_module-0.3.6/nerdd_module/model/assign_mol_id_step.py +17 -0
  32. nerdd_module-0.3.6/nerdd_module/model/assign_name_step.py +17 -0
  33. nerdd_module-0.3.6/nerdd_module/model/convert_representations_step.py +18 -0
  34. nerdd_module-0.3.6/nerdd_module/model/enforce_schema_step.py +28 -0
  35. nerdd_module-0.3.6/nerdd_module/model/model.py +260 -0
  36. nerdd_module-0.3.6/nerdd_module/model/read_input_step.py +24 -0
  37. nerdd_module-0.3.6/nerdd_module/model/simple_model.py +174 -0
  38. nerdd_module-0.3.6/nerdd_module/model/write_output_step.py +19 -0
  39. nerdd_module-0.3.6/nerdd_module/output/__init__.py +6 -0
  40. nerdd_module-0.3.6/nerdd_module/output/csv_writer.py +24 -0
  41. nerdd_module-0.3.6/nerdd_module/output/file_writer.py +41 -0
  42. nerdd_module-0.3.6/nerdd_module/output/iterator_writer.py +13 -0
  43. nerdd_module-0.3.6/nerdd_module/output/pandas_writer.py +16 -0
  44. nerdd_module-0.3.6/nerdd_module/output/record_list_writer.py +13 -0
  45. nerdd_module-0.3.6/nerdd_module/output/sdf_writer.py +33 -0
  46. nerdd_module-0.3.6/nerdd_module/output/writer.py +54 -0
  47. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/polyfills/__init__.py +1 -0
  48. nerdd_module-0.3.6/nerdd_module/polyfills/files.py +13 -0
  49. nerdd_module-0.3.6/nerdd_module/polyfills/get_entry_points.py +27 -0
  50. nerdd_module-0.3.6/nerdd_module/polyfills/types.py +14 -0
  51. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/preprocessing/__init__.py +1 -4
  52. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/preprocessing/check_valid_smiles.py +7 -7
  53. nerdd_module-0.3.6/nerdd_module/preprocessing/chembl_structure_pipeline.py +78 -0
  54. nerdd_module-0.3.6/nerdd_module/preprocessing/filter_by_element.py +67 -0
  55. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/preprocessing/filter_by_weight.py +15 -13
  56. nerdd_module-0.3.6/nerdd_module/preprocessing/preprocessing_step.py +61 -0
  57. nerdd_module-0.3.6/nerdd_module/preprocessing/remove_stereochemistry.py +24 -0
  58. nerdd_module-0.3.6/nerdd_module/preprocessing/sanitize.py +21 -0
  59. nerdd_module-0.3.6/nerdd_module/problem.py +16 -0
  60. nerdd_module-0.3.6/nerdd_module/steps/__init__.py +3 -0
  61. nerdd_module-0.3.6/nerdd_module/steps/map_step.py +38 -0
  62. nerdd_module-0.3.6/nerdd_module/steps/output_step.py +27 -0
  63. nerdd_module-0.3.6/nerdd_module/steps/step.py +27 -0
  64. nerdd_module-0.3.6/nerdd_module/tests/checks.py +188 -0
  65. nerdd_module-0.3.6/nerdd_module/tests/models/AtomicMassModel.py +65 -0
  66. nerdd_module-0.3.6/nerdd_module/tests/models/MolWeightModel.py +49 -0
  67. nerdd_module-0.3.6/nerdd_module/tests/models/__init__.py +2 -0
  68. nerdd_module-0.3.6/nerdd_module/tests/predictions.py +56 -0
  69. nerdd_module-0.3.6/nerdd_module/tests/preprocessing/DummyPreprocessingStep.py +25 -0
  70. nerdd_module-0.3.6/nerdd_module/tests/preprocessing/__init__.py +1 -0
  71. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/tests/representations.py +10 -6
  72. nerdd_module-0.3.6/nerdd_module/util/__init__.py +2 -0
  73. nerdd_module-0.3.6/nerdd_module/util/call_with_mappings.py +51 -0
  74. nerdd_module-0.3.6/nerdd_module/util/package.py +25 -0
  75. nerdd_module-0.3.6/nerdd_module.egg-info/PKG-INFO +105 -0
  76. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module.egg-info/SOURCES.txt +34 -24
  77. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module.egg-info/requires.txt +13 -5
  78. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module.egg-info/top_level.txt +0 -1
  79. nerdd_module-0.3.6/pyproject.toml +134 -0
  80. nerdd_module-0.3.6/tests/test_features.py +3 -0
  81. nerdd_module-0.2.6/PKG-INFO +0 -73
  82. nerdd_module-0.2.6/README.md +0 -18
  83. nerdd_module-0.2.6/nerdd_module/abstract_model.py +0 -271
  84. nerdd_module-0.2.6/nerdd_module/config/auto_configuration.py +0 -62
  85. nerdd_module-0.2.6/nerdd_module/config/configuration.py +0 -52
  86. nerdd_module-0.2.6/nerdd_module/config/merged_configuration.py +0 -18
  87. nerdd_module-0.2.6/nerdd_module/config/yaml_configuration.py +0 -56
  88. nerdd_module-0.2.6/nerdd_module/input/depth_first_explorer.py +0 -111
  89. nerdd_module-0.2.6/nerdd_module/input/explorer.py +0 -13
  90. nerdd_module-0.2.6/nerdd_module/input/reader.py +0 -25
  91. nerdd_module-0.2.6/nerdd_module/input/reader_registry.py +0 -64
  92. nerdd_module-0.2.6/nerdd_module/output/__init__.py +0 -1
  93. nerdd_module-0.2.6/nerdd_module/output/csv_writer.py +0 -30
  94. nerdd_module-0.2.6/nerdd_module/output/sdf_writer.py +0 -35
  95. nerdd_module-0.2.6/nerdd_module/output/writer.py +0 -45
  96. nerdd_module-0.2.6/nerdd_module/output/writer_registry.py +0 -40
  97. nerdd_module-0.2.6/nerdd_module/polyfills/files.py +0 -8
  98. nerdd_module-0.2.6/nerdd_module/polyfills/get_entry_points.py +0 -18
  99. nerdd_module-0.2.6/nerdd_module/preprocessing/chembl_structure_pipeline.py +0 -124
  100. nerdd_module-0.2.6/nerdd_module/preprocessing/empty_pipeline.py +0 -8
  101. nerdd_module-0.2.6/nerdd_module/preprocessing/filter_by_element.py +0 -39
  102. nerdd_module-0.2.6/nerdd_module/preprocessing/pipeline.py +0 -53
  103. nerdd_module-0.2.6/nerdd_module/preprocessing/registry.py +0 -20
  104. nerdd_module-0.2.6/nerdd_module/preprocessing/remove_stereochemistry.py +0 -24
  105. nerdd_module-0.2.6/nerdd_module/preprocessing/sanitize.py +0 -18
  106. nerdd_module-0.2.6/nerdd_module/preprocessing/step.py +0 -26
  107. nerdd_module-0.2.6/nerdd_module/problem.py +0 -13
  108. nerdd_module-0.2.6/nerdd_module/tests/checks.py +0 -184
  109. nerdd_module-0.2.6/nerdd_module/tests/predictions.py +0 -30
  110. nerdd_module-0.2.6/nerdd_module.egg-info/PKG-INFO +0 -73
  111. nerdd_module-0.2.6/setup.py +0 -87
  112. nerdd_module-0.2.6/tests/conftest.py +0 -7
  113. nerdd_module-0.2.6/tests/models/AtomicMassModel.py +0 -29
  114. nerdd_module-0.2.6/tests/models/MolWeightModel.py +0 -25
  115. nerdd_module-0.2.6/tests/models/MolWeightModelWithExplicitMolIds.py +0 -30
  116. nerdd_module-0.2.6/tests/models/MolWeightModelWithExplicitMols.py +0 -27
  117. nerdd_module-0.2.6/tests/models/__init__.py +0 -4
  118. nerdd_module-0.2.6/tests/steps/__init__.py +0 -3
  119. nerdd_module-0.2.6/tests/steps/checks.py +0 -45
  120. nerdd_module-0.2.6/tests/steps/predictors.py +0 -52
  121. nerdd_module-0.2.6/tests/steps/preprocessing.py +0 -9
  122. nerdd_module-0.2.6/tests/test_atom_property_prediction.py +0 -66
  123. nerdd_module-0.2.6/tests/test_molecule_property_prediction.py +0 -60
  124. nerdd_module-0.2.6/tests/test_preprocessing.py +0 -12
  125. nerdd_module-0.2.6/tests/test_reading_formats.py +0 -137
  126. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/LICENSE +0 -0
  127. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/polyfills/version.py +0 -0
  128. /nerdd_module-0.2.6/tests/__init__.py → /nerdd_module-0.3.6/nerdd_module/py.typed +0 -0
  129. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/tests/__init__.py +0 -0
  130. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module/version.py +0 -0
  131. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/nerdd_module.egg-info/dependency_links.txt +0 -0
  132. {nerdd_module-0.2.6 → nerdd_module-0.3.6}/setup.cfg +0 -0
@@ -0,0 +1,105 @@
1
+ Metadata-Version: 2.1
2
+ Name: nerdd-module
3
+ Version: 0.3.6
4
+ Summary: Base package to create NERDD modules
5
+ Author-email: Steffen Hirte <steffen.hirte@univie.ac.at>
6
+ Maintainer-email: Steffen Hirte <steffen.hirte@univie.ac.at>
7
+ License: BSD 3-Clause License
8
+
9
+ Copyright (c) 2023 - present, The Computational Drug Discovery and Design Group (COMP3D)
10
+
11
+ Redistribution and use in source and binary forms, with or without
12
+ modification, are permitted provided that the following conditions are met:
13
+
14
+ 1. Redistributions of source code must retain the above copyright notice, this
15
+ list of conditions and the following disclaimer.
16
+
17
+ 2. Redistributions in binary form must reproduce the above copyright notice,
18
+ this list of conditions and the following disclaimer in the documentation
19
+ and/or other materials provided with the distribution.
20
+
21
+ 3. Neither the name of the copyright holder nor the names of its
22
+ contributors may be used to endorse or promote products derived from
23
+ this software without specific prior written permission.
24
+
25
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35
+ Project-URL: Repository, https://github.com/molinfo-vienna/nerdd-module
36
+ Keywords: science,research,development,nerdd
37
+ Classifier: Intended Audience :: Science/Research
38
+ Classifier: Intended Audience :: Developers
39
+ Classifier: License :: OSI Approved :: BSD License
40
+ Classifier: Programming Language :: Python
41
+ Classifier: Topic :: Software Development
42
+ Classifier: Topic :: Scientific/Engineering
43
+ Classifier: Operating System :: Microsoft :: Windows
44
+ Classifier: Operating System :: POSIX
45
+ Classifier: Operating System :: Unix
46
+ Classifier: Operating System :: MacOS
47
+ Classifier: Programming Language :: Python :: 3
48
+ Classifier: Programming Language :: Python :: 3.9
49
+ Classifier: Programming Language :: Python :: 3.10
50
+ Classifier: Programming Language :: Python :: 3.11
51
+ Classifier: Programming Language :: Python :: 3.12
52
+ Description-Content-Type: text/markdown
53
+ License-File: LICENSE
54
+ Requires-Dist: pandas>=1.2.1
55
+ Requires-Dist: pyyaml>=6.0
56
+ Requires-Dist: filetype~=1.2.0
57
+ Requires-Dist: rich-click>=1.7.1
58
+ Requires-Dist: stringcase>=1.2.0
59
+ Requires-Dist: decorator>=5.1.1
60
+ Requires-Dist: importlib-resources>=5; python_version < "3.9"
61
+ Requires-Dist: importlib-metadata>=4.6; python_version < "3.10"
62
+ Provides-Extra: dev
63
+ Requires-Dist: mypy; extra == "dev"
64
+ Requires-Dist: ruff; extra == "dev"
65
+ Requires-Dist: pandas-stubs; extra == "dev"
66
+ Requires-Dist: types-PyYAML; extra == "dev"
67
+ Requires-Dist: types-decorator; extra == "dev"
68
+ Requires-Dist: types-setuptools; extra == "dev"
69
+ Provides-Extra: rdkit
70
+ Requires-Dist: rdkit>=2022.3.3; extra == "rdkit"
71
+ Provides-Extra: csp
72
+ Requires-Dist: chembl_structure_pipeline>=1.0.0; extra == "csp"
73
+ Provides-Extra: test
74
+ Requires-Dist: pytest; extra == "test"
75
+ Requires-Dist: pytest-sugar; extra == "test"
76
+ Requires-Dist: pytest-cov; extra == "test"
77
+ Requires-Dist: pytest-asyncio; extra == "test"
78
+ Requires-Dist: pytest-bdd; extra == "test"
79
+ Requires-Dist: pytest-mock; extra == "test"
80
+ Requires-Dist: pytest-watcher; extra == "test"
81
+ Requires-Dist: hypothesis; extra == "test"
82
+ Requires-Dist: hypothesis-rdkit; extra == "test"
83
+ Provides-Extra: docs
84
+ Requires-Dist: mkdocs; extra == "docs"
85
+ Requires-Dist: mkdocs-material; extra == "docs"
86
+ Requires-Dist: mkdocstrings; extra == "docs"
87
+
88
+ # NERDD Module
89
+
90
+ This package provides the basis to implement molecular prediction modules in the
91
+ NERDD ecosystem.
92
+
93
+ ## Installation
94
+
95
+ ``` bash
96
+ pip install -U nerdd-module
97
+ ```
98
+
99
+
100
+ ## Contribute
101
+
102
+ 1. Fork and clone the code
103
+ 2. Install test dependencies with ` pip install -e .[test,dev,csp]`
104
+ 3. Run tests via `pytest` or `pytest-watch` (short: `ptw`)
105
+ 4. Build docs via ` pip install -e .[docs]` and ` mkdocs serve`
@@ -0,0 +1,18 @@
1
+ # NERDD Module
2
+
3
+ This package provides the basis to implement molecular prediction modules in the
4
+ NERDD ecosystem.
5
+
6
+ ## Installation
7
+
8
+ ``` bash
9
+ pip install -U nerdd-module
10
+ ```
11
+
12
+
13
+ ## Contribute
14
+
15
+ 1. Fork and clone the code
16
+ 2. Install test dependencies with ` pip install -e .[test,dev,csp]`
17
+ 3. Run tests via `pytest` or `pytest-watch` (short: `ptw`)
18
+ 4. Build docs via ` pip install -e .[docs]` and ` mkdocs serve`
@@ -1,10 +1,8 @@
1
- from .abstract_model import *
2
1
  from .cli import *
3
- from .config import *
2
+ from .model import *
3
+ from .polyfills import get_entry_points
4
4
  from .problem import *
5
5
  from .version import *
6
- from .polyfills import get_entry_points
7
-
8
6
 
9
7
  for entry_point in get_entry_points("nerdd-module.plugins"):
10
8
  entry_point.load()
@@ -1,11 +1,12 @@
1
1
  import logging
2
- import os
3
2
  import sys
3
+ from typing import Any, Callable
4
4
 
5
5
  import rich_click as click
6
6
  from decorator import decorator
7
- from nerdd_module.output import WriterRegistry
8
- from stringcase import spinalcase
7
+ from stringcase import spinalcase # type: ignore
8
+
9
+ from .model import Model
9
10
 
10
11
  __all__ = ["auto_cli"]
11
12
 
@@ -14,100 +15,97 @@ input_description = """{description}
14
15
  INPUT molecules are provided as file paths or strings. The following formats are
15
16
  supported:
16
17
 
17
- {format_list}
18
+ {input_format_list}
18
19
 
19
20
  Note that input formats shouldn't be mixed.
20
21
  """
21
22
 
22
23
 
23
- def infer_click_type(param):
24
+ def infer_click_type(param: dict) -> click.ParamType:
24
25
  if "choices" in param:
25
26
  choices = [c["value"] for c in param["choices"]]
26
27
  return click.Choice(choices)
27
28
 
28
29
  type_map = {
29
- "float": float,
30
- "int": int,
31
- "str": str,
32
- "bool": bool,
30
+ "float": click.FLOAT,
31
+ "int": click.INT,
32
+ "str": click.STRING,
33
+ "bool": click.BOOL,
33
34
  }
34
35
 
35
- return type_map[param.get("type")]
36
+ if "type" not in param:
37
+ raise ValueError(f"Parameter {param['name']} does not have a type")
38
+
39
+ t = param["type"]
40
+
41
+ if t not in type_map:
42
+ raise ValueError(f"Unknown type {t} for parameter {param['name']}")
43
+
44
+ return type_map[t]
36
45
 
37
46
 
38
47
  @decorator
39
- def auto_cli(f, *args, **kwargs):
48
+ def auto_cli(f: Callable[..., Model], *args: Any, **kwargs: Any) -> None:
40
49
  # infer the command name
41
- command_name = os.path.basename(sys.argv[0])
50
+ # command_name = os.path.basename(sys.argv[0])
42
51
 
43
52
  # get the model
44
53
  model = f()
45
54
 
46
- config = model.get_config().get_dict()
47
-
48
55
  # compose cli description
49
- description = config.get("description", "")
50
-
51
- format_list = "\n".join([f"* {fmt}" for fmt in ["smiles", "sdf", "inchi"]])
56
+ input_format_list = "\n".join([f"* {fmt}" for fmt in ["smiles", "sdf", "inchi"]])
52
57
 
53
58
  help_text = input_description.format(
54
- description=description, format_list=format_list
59
+ description=model.description, input_format_list=input_format_list
55
60
  )
56
61
 
57
- # compose footer with examples
58
- examples = []
59
- if "example_smiles" in config:
60
- examples.append(config["example_smiles"])
62
+ output_format_list = ["sdf", "csv"]
61
63
 
62
- if len(examples) > 0:
63
- footer = "Examples:\n"
64
- for example in examples:
65
- footer += f'* {command_name} "{example}"\n'
66
- else:
67
- footer = ""
64
+ # compose footer with examples
65
+ # TODO: add examples
66
+ # examples = []
67
+ # if "example_smiles" in config:
68
+ # examples.append(config["example_smiles"])
69
+
70
+ # if len(examples) > 0:
71
+ # footer = "Examples:\n"
72
+ # for example in examples:
73
+ # footer += f'* {command_name} "{example}"\n'
74
+ # else:
75
+ # footer = ""
76
+ footer = ""
68
77
 
69
- # show_default=True: default values are shown in the help text
70
- # show_metavars_column=False: the column types are not in a separate column
71
- # append_metavars_help=True: the column types are shown below the help text
72
- @click.command(context_settings={"show_default": True}, help=help_text)
73
- @click.rich_config(
74
- help_config=click.RichHelpConfiguration(
75
- use_markdown=True,
76
- show_metavars_column=False,
77
- append_metavars_help=True,
78
- footer_text=footer,
79
- )
80
- )
81
- @click.argument("input", type=click.Path(), nargs=-1, required=True)
78
+ #
79
+ # Define the CLI entry point
80
+ #
82
81
  def main(
83
- input,
82
+ input: Any,
84
83
  format: str,
85
84
  output: click.Path,
86
85
  log_level: str,
87
- **kwargs,
88
- ):
86
+ **kwargs: Any,
87
+ ) -> None:
89
88
  logging.basicConfig(level=log_level.upper())
90
89
 
91
- df_result = model.predict(input, **kwargs)
92
-
93
90
  # write results
94
- assert format in WriterRegistry().supported_formats
95
- writer = WriterRegistry().get_writer(format)
91
+ assert format in output_format_list, f"Unknown output format: {format}"
96
92
 
97
- if output.lower() == "stdout":
98
- assert not writer.writes_bytes, "stdout does not support binary output"
93
+ if str(output).lower() == "stdout":
99
94
  output_handle = sys.stdout
100
95
  else:
101
- mode = "wb" if writer.writes_bytes else "w"
102
- output_handle = click.open_file(output, mode)
96
+ output_handle = click.open_file(str(output), "wb")
103
97
 
104
- entries = (tup._asdict() for tup in df_result.itertuples(index=False))
105
- writer.write(output_handle, entries)
98
+ model.predict(input, output_format=format, output_file=output_handle, **kwargs)
99
+
100
+ #
101
+ # Add required input parameter
102
+ #
103
+ main = click.argument("input", type=click.Path(), nargs=-1, required=True)(main)
106
104
 
107
105
  #
108
106
  # Add job parameters
109
107
  #
110
- for param in config.get("job_parameters", []):
108
+ for param in model.job_parameters:
111
109
  # convert parameter name to spinal case (e.g. "max_confs" -> "max-confs")
112
110
  param_name = spinalcase(param["name"])
113
111
  main = click.option(
@@ -130,17 +128,33 @@ def auto_cli(f, *args, **kwargs):
130
128
  main = click.option(
131
129
  "--format",
132
130
  default="csv",
133
- type=click.Choice(WriterRegistry().supported_formats, case_sensitive=False),
131
+ type=click.Choice(output_format_list, case_sensitive=False),
134
132
  help="The output format.",
135
133
  )(main)
136
134
 
137
135
  main = click.option(
138
136
  "--log-level",
139
137
  default="warning",
140
- type=click.Choice(
141
- ["debug", "info", "warning", "error", "critical"], case_sensitive=False
142
- ),
138
+ type=click.Choice(["debug", "info", "warning", "error", "critical"], case_sensitive=False),
143
139
  help="The logging level.",
144
140
  )(main)
145
141
 
142
+ #
143
+ # Create Rich command
144
+ #
145
+
146
+ # show_metavars_column=False: the column types are not in a separate column
147
+ # append_metavars_help=True: the column types are shown below the help text
148
+ main = click.rich_config(
149
+ help_config=click.RichHelpConfiguration(
150
+ use_markdown=True,
151
+ show_metavars_column=False,
152
+ append_metavars_help=True,
153
+ footer_text=footer,
154
+ )
155
+ )(main)
156
+
157
+ # show_default=True: default values are shown in the help text
158
+ main = click.command(context_settings={"show_default": True}, help=help_text)(main)
159
+
146
160
  return main()
@@ -1,7 +1,7 @@
1
- from .auto_configuration import *
2
1
  from .configuration import *
3
2
  from .default_configuration import *
4
3
  from .dict_configuration import *
5
4
  from .merged_configuration import *
6
5
  from .package_configuration import *
6
+ from .search_yaml_configuration import *
7
7
  from .yaml_configuration import *
@@ -0,0 +1,90 @@
1
+ from abc import ABC, abstractmethod
2
+ from functools import lru_cache
3
+ from typing import List
4
+
5
+ __all__ = ["Configuration"]
6
+
7
+
8
def get_property_columns_of_type(config: dict, t: str) -> List[dict]:
    """Return the result properties of ``config`` that are declared at level ``t``.

    A result property without a ``level`` entry counts as a ``"molecule"``
    property.

    Args:
        config: Configuration dictionary. Its ``result_properties`` entry is
            scanned; a missing entry is treated as an empty list.
        t: Level to select, compared against each property's ``level`` value.

    Returns:
        The matching property dictionaries in their original order.
    """
    # Use .get() so that a dictionary without "result_properties" yields an
    # empty result instead of raising KeyError.
    return [
        c for c in config.get("result_properties", []) if c.get("level", "molecule") == t
    ]
10
+
11
+
12
def is_visible(result_property: dict, output_format: str) -> bool:
    """Decide whether a result property is shown in the given output format.

    The decision is driven by the property's optional ``formats`` entry:

    * missing: the property is visible in every format,
    * a list: the property is visible only in the listed formats,
    * a dict: ``include`` (a list, or ``"*"`` for all formats, the default) and
      ``exclude`` (a list, empty by default) are combined; a format must be
      included and not excluded.

    Args:
        result_property: A single entry of ``result_properties``.
        output_format: Name of the output format to check.

    Returns:
        ``True`` if the property is visible in ``output_format``.

    Raises:
        ValueError: If ``formats`` is neither a list nor a dict, or if
            ``include`` / ``exclude`` have an unexpected type.
    """
    formats = result_property.get("formats", {})

    if isinstance(formats, list):
        return output_format in formats

    if not isinstance(formats, dict):
        raise ValueError(
            f"Invalid formats declaration {formats} in result property {result_property}"
        )

    include = formats.get("include", "*")
    exclude = formats.get("exclude", [])

    # Raise instead of assert: assertions are removed under "python -O", and a
    # plain string would then be checked with substring semantics below.
    if include != "*" and not isinstance(include, list):
        raise ValueError(f"Expected include to be a list or '*', got {include}")
    if not isinstance(exclude, list):
        raise ValueError(f"Expected exclude to be a list, got {exclude}")

    return (include == "*" or output_format in include) and output_format not in exclude
29
+
30
+
31
class Configuration(ABC):
    """Base class for configurations that expose their settings as a dictionary.

    Subclasses implement :meth:`_get_dict`. Callers use :meth:`get_dict`, which
    fills in a default for ``result_properties``, checks the declared property
    levels and caches the result on the instance.
    """

    def __init__(self) -> None:
        pass

    def get_dict(self) -> dict:
        """Return the validated configuration dictionary.

        The dictionary is computed on the first call and stored on the
        instance; later calls return the same object.

        Returns:
            A dictionary that always contains a ``result_properties`` entry.

        Raises:
            AssertionError: If the configuration declares both atom-level and
                derivative-level result properties.
        """
        # The cache lives on the instance (set lazily, so subclasses that skip
        # super().__init__() still work). A method-level lru_cache would hold a
        # reference to every instance and, with a single slot, be evicted as
        # soon as a second configuration is queried.
        cached = getattr(self, "_cached_dict", None)
        if cached is not None:
            return cached

        # Work on a shallow copy so that adding the default below does not
        # modify the dictionary owned by the subclass (or by its creator).
        config = dict(self._get_dict())
        config.setdefault("result_properties", [])

        # check that a module can only predict atom or derivative properties, not both
        num_atom_properties = len(get_property_columns_of_type(config, "atom"))
        num_derivative_properties = len(get_property_columns_of_type(config, "derivative"))
        assert (
            num_atom_properties == 0 or num_derivative_properties == 0
        ), "A module can only predict atom or derivative properties, not both."

        self._cached_dict = config
        return config

    @abstractmethod
    def _get_dict(self) -> dict:
        """Return the raw configuration dictionary of the concrete subclass."""
        pass

    def is_empty(self) -> bool:
        """Return ``True`` if the configuration carries no settings.

        ``get_dict`` always provides a ``result_properties`` entry, so a
        comparison with ``{}`` could never succeed. A configuration is
        therefore considered empty when it has no entries other than an empty
        ``result_properties``.
        """
        return all(
            key == "result_properties" and not value for key, value in self.get_dict().items()
        )

    def molecular_property_columns(self) -> List[dict]:
        """Return the result properties declared at molecule level."""
        return get_property_columns_of_type(self.get_dict(), "molecule")

    def atom_property_columns(self) -> List[dict]:
        """Return the result properties declared at atom level."""
        return get_property_columns_of_type(self.get_dict(), "atom")

    def derivative_property_columns(self) -> List[dict]:
        """Return the result properties declared at derivative level."""
        return get_property_columns_of_type(self.get_dict(), "derivative")

    def get_task(self) -> str:
        """Return the prediction task of the module.

        An explicit ``task`` entry in the configuration wins. Otherwise the
        task is derived from the levels of the declared result properties.
        """
        # if task is specified in the config, use that
        config = self.get_dict()
        if "task" in config:
            return config["task"]

        # try to derive the task from the result_properties
        num_atom_properties = len(self.atom_property_columns())
        num_derivative_properties = len(self.derivative_property_columns())

        if num_atom_properties > 0:
            return "atom_property_prediction"
        elif num_derivative_properties > 0:
            return "derivative_property_prediction"
        else:
            return "molecular_property_prediction"

    def get_visible_properties(self, output_format: str) -> List[dict]:
        """Return the result properties that are visible in ``output_format``."""
        return [
            p for p in self.get_dict().get("result_properties", []) if is_visible(p, output_format)
        ]

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._get_dict()})"
@@ -1,15 +1,15 @@
1
- from stringcase import snakecase
1
+ from typing import Any
2
+
3
+ from stringcase import snakecase # type: ignore
2
4
 
3
5
  from ..polyfills import version
4
- from .configuration import Configuration
6
+ from .dict_configuration import DictConfiguration
5
7
 
6
8
  __all__ = ["DefaultConfiguration"]
7
9
 
8
10
 
9
- class DefaultConfiguration(Configuration):
10
- def __init__(self, nerdd_module):
11
- super().__init__()
12
-
11
+ class DefaultConfiguration(DictConfiguration):
12
+ def __init__(self, nerdd_module: Any) -> None:
13
13
  # generate a name from the module name
14
14
  class_name = nerdd_module.__class__.__name__
15
15
  if class_name.endswith("Model"):
@@ -25,17 +25,15 @@ class DefaultConfiguration(Configuration):
25
25
  try:
26
26
  module = nerdd_module.__module__
27
27
  root_module = module.split(".", 1)[0]
28
- version_ = version(root_module)
28
+ package_version = version(root_module)
29
29
  except ModuleNotFoundError:
30
- pass
30
+ package_version = "0.0.1"
31
31
 
32
- self.config = dict(
32
+ config = dict(
33
33
  name=name,
34
- version=version_,
35
- task="molecular_property_prediction",
34
+ version=package_version,
36
35
  job_parameters=[],
37
36
  result_properties=[],
38
37
  )
39
38
 
40
- def _get_dict(self):
41
- return self.config
39
+ super().__init__(config)
@@ -4,10 +4,9 @@ __all__ = ["DictConfiguration"]
4
4
 
5
5
 
6
6
  class DictConfiguration(Configuration):
7
- def __init__(self, config):
7
+ def __init__(self, config: dict) -> None:
8
8
  super().__init__()
9
+ self._config = config
9
10
 
10
- self.config = config
11
-
12
- def _get_dict(self):
13
- return self.config
11
+ def _get_dict(self) -> dict:
12
+ return self._config
@@ -0,0 +1,44 @@
1
+ from collections import Counter
2
+
3
+ from .configuration import Configuration
4
+ from .dict_configuration import DictConfiguration
5
+
6
+ __all__ = ["MergedConfiguration"]
7
+
8
+
9
def merge(*args: dict) -> dict:
    """Recursively merge configuration values; later values take precedence.

    All values must be instances of the type of the first value.

    * Lists are concatenated in argument order.
    * Dictionaries are merged key by key: a key that occurs in a single
      dictionary is copied, a key that occurs in several dictionaries is merged
      recursively. Keys that occur once come first in the result.
    * For any other type the last value is returned.
    """
    assert len(args) > 0

    head = args[0]
    assert all(isinstance(value, type(head)) for value in args)

    if isinstance(head, list):
        combined = []
        for value in args:
            combined.extend(value)
        return combined

    if not isinstance(head, dict):
        # scalar values: the last configuration has the highest priority
        return args[-1]

    # number of dictionaries each key appears in
    occurrences = Counter(key for mapping in args for key in mapping)

    result = {}

    # keys that occur in only one dictionary are taken over unchanged
    for mapping in args:
        for key, value in mapping.items():
            if occurrences[key] == 1:
                result[key] = value

    # keys that occur in several dictionaries are merged recursively
    for key, count in occurrences.items():
        if count > 1:
            result[key] = merge(*[mapping[key] for mapping in args if key in mapping])

    return result
40
+
41
+
42
class MergedConfiguration(DictConfiguration):
    """Configuration obtained by merging several configurations.

    The dictionaries of the given configurations are combined with ``merge``,
    so configurations passed later take precedence. Without any configuration
    the result is an empty configuration.
    """

    def __init__(self, *configs: Configuration) -> None:
        dicts = [c.get_dict() for c in configs]
        # merge() requires at least one value, so handle "nothing to merge" here
        super().__init__(merge(*dicts) if dicts else {})
@@ -1,3 +1,5 @@
1
+ import logging
2
+
1
3
  from ..polyfills import files
2
4
  from .configuration import Configuration
3
5
  from .dict_configuration import DictConfiguration
@@ -5,9 +7,11 @@ from .yaml_configuration import YamlConfiguration
5
7
 
6
8
  __all__ = ["PackageConfiguration"]
7
9
 
10
+ logger = logging.getLogger(__name__)
11
+
8
12
 
9
13
  class PackageConfiguration(Configuration):
10
- def __init__(self, package):
14
+ def __init__(self, package: str) -> None:
11
15
  super().__init__()
12
16
 
13
17
  # get the resource directory
@@ -16,16 +20,16 @@ class PackageConfiguration(Configuration):
16
20
  except ModuleNotFoundError:
17
21
  root_dir = None
18
22
 
19
- if root_dir is None:
20
- self.config = DictConfiguration({})
21
- else:
23
+ self.config: Configuration = DictConfiguration({})
24
+ if root_dir is not None:
22
25
  # navigate to the config file
23
26
  config_file = root_dir / "nerdd.yml"
24
27
 
25
- if config_file is not None and config_file.exists():
26
- self.config = YamlConfiguration(config_file, base_path=root_dir)
28
+ if config_file is not None and config_file.is_file():
29
+ logger.info(f"Found configuration file in package: {config_file}")
30
+ self.config = YamlConfiguration(config_file.open(), base_path=root_dir)
27
31
  else:
28
32
  self.config = DictConfiguration({})
29
33
 
30
- def _get_dict(self):
34
+ def _get_dict(self) -> dict:
31
35
  return self.config.get_dict()
@@ -0,0 +1,37 @@
1
+ import logging
2
+ import os
3
+ from typing import Optional
4
+
5
+ from .configuration import Configuration
6
+ from .dict_configuration import DictConfiguration
7
+ from .yaml_configuration import YamlConfiguration
8
+
9
+ __all__ = ["SearchYamlConfiguration"]
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
class SearchYamlConfiguration(DictConfiguration):
    """Configuration read from the nearest ``nerdd.yml`` at or above a directory.

    Args:
        start: Directory at which the search begins. The search moves up one
            directory at a time until a ``nerdd.yml`` file is found or the
            path has no further parent. ``None`` skips the search.
        base_path: Passed on unchanged to ``YamlConfiguration``.

    If no file is found, the configuration is built from an empty
    ``DictConfiguration``.
    """

    def __init__(self, start: Optional[str], base_path: Optional[str] = None) -> None:
        # provide a default configuration if no configuration file is found
        config: Configuration = DictConfiguration({})

        # Defined up front so the check below never reads an unbound name,
        # including when no search takes place.
        default_config_file: Optional[str] = None

        if start is not None:
            # start at the given directory and go up the directory tree until
            # nerdd.yml is found (or root is reached)
            leaf = start
            while True:
                candidate = os.path.join(leaf, "nerdd.yml")
                if os.path.isfile(candidate):
                    default_config_file = candidate
                    break

                parent = os.path.dirname(leaf)
                if parent == leaf:  # reached root
                    break
                leaf = parent

        if default_config_file is not None:
            logger.info(f"Found configuration file in project directory: {default_config_file}")
            config = YamlConfiguration(default_config_file, base_path)

        super().__init__(config.get_dict())