reflectorch 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/inference/__init__.py +2 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/runs/utils.py +6 -7
- {reflectorch-1.0.1.dist-info → reflectorch-1.1.0.dist-info}/METADATA +14 -10
- {reflectorch-1.0.1.dist-info → reflectorch-1.1.0.dist-info}/RECORD +8 -7
- {reflectorch-1.0.1.dist-info → reflectorch-1.1.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.0.1.dist-info → reflectorch-1.1.0.dist-info}/LICENSE.txt +0 -0
- {reflectorch-1.0.1.dist-info → reflectorch-1.1.0.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from reflectorch.inference.inference_model import InferenceModel, EasyInferenceModel
|
|
2
|
+
from reflectorch.inference.query_matcher import HuggingfaceQueryMatcher
|
|
2
3
|
from reflectorch.inference.multilayer_inference_model import MultilayerInferenceModel
|
|
3
4
|
from reflectorch.inference.preprocess_exp import (
|
|
4
5
|
StandardPreprocessing,
|
|
@@ -13,6 +14,7 @@ __all__ = [
|
|
|
13
14
|
"InferenceModel",
|
|
14
15
|
"EasyInferenceModel",
|
|
15
16
|
"MultilayerInferenceModel",
|
|
17
|
+
"HuggingfaceQueryMatcher",
|
|
16
18
|
"StandardPreprocessing",
|
|
17
19
|
"standard_preprocessing",
|
|
18
20
|
"ReflGradientFit",
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import tempfile
|
|
3
|
+
import yaml
|
|
4
|
+
from huggingface_hub import hf_hub_download, list_repo_files
|
|
5
|
+
|
|
6
|
+
class HuggingfaceQueryMatcher:
    """Downloads the available configuration files to a temporary directory and provides functionality for filtering those configuration files matching user specified queries.

    Args:
        repo_id (str): The Hugging Face repository ID.
        config_dir (str): Directory within the repo where YAML files are stored.
    """
    def __init__(self, repo_id='valentinsingularity/reflectivity', config_dir='configs'):
        self.repo_id = repo_id
        self.config_dir = config_dir
        # 'parsed_configs': file name -> parsed YAML content of that config file.
        # 'temp_dir': local directory the config files were downloaded into.
        self.cache = {
            'parsed_configs': None,
            'temp_dir': None
        }
        self._renew_cache()

    def _renew_cache(self):
        """Download every ``.yaml`` file under ``config_dir`` into a new temporary directory and parse it.

        On success, fills ``self.cache['parsed_configs']`` (base file name -> parsed YAML) and
        ``self.cache['temp_dir']``. The temporary directory is not deleted by this class.
        """
        temp_dir = tempfile.mkdtemp()
        print(f"Temporary directory created at: {temp_dir}")

        # Match on the directory prefix *including* the separator, so that e.g. 'configs_old/x.yaml'
        # is not selected when config_dir='configs'. An empty config_dir selects the whole repository.
        config_dir = self.config_dir.strip('/')
        prefix = f'{config_dir}/' if config_dir else ''

        repo_files = list_repo_files(self.repo_id, repo_type='model')
        config_files = [
            repo_file for repo_file in repo_files
            if repo_file.startswith(prefix) and repo_file.endswith('.yaml')
        ]

        downloaded_files = [
            hf_hub_download(repo_id=self.repo_id, filename=repo_file, local_dir=temp_dir, repo_type='model')
            for repo_file in config_files
        ]

        parsed_configs = {}
        for file_path in downloaded_files:
            with open(file_path, 'r', encoding='utf-8') as handle:
                config_data = yaml.safe_load(handle)
            # NOTE: keyed by base name only; configs with the same name in different
            # subdirectories of config_dir would overwrite each other.
            parsed_configs[os.path.basename(file_path)] = config_data

        # Only update the cache once everything was downloaded and parsed successfully.
        self.cache['parsed_configs'] = parsed_configs
        self.cache['temp_dir'] = temp_dir

    def get_matching_configs(self, query):
        """Retrieves configuration files that match the user specified query.

        Args:
            query (dict): Dictionary of key-value pairs to filter configurations, e.g. ``query = {'dset.prior_sampler.kwargs.max_num_layers': 3, 'dset.prior_sampler.kwargs.param_ranges.slds': [0., 100.]}``.
                For keys containing the ``param_ranges`` subkey a configuration is selected if the value of the query (i.e. desired parameter range)
                is a subrange of the parameter range in the configuration, in all other cases the values must match exactly.

        Returns:
            list: List of file names that match the query (in cache order).
        """
        parsed_configs = self.cache['parsed_configs'] or {}
        return [
            file_name for file_name, config_data in parsed_configs.items()
            if self.matches_query(config_data, query)
        ]

    def matches_query(self, config_data, query):
        """Return True if ``config_data`` satisfies every condition in ``query``.

        Each query key is a dot-separated path into the nested config. If the path contains
        ``param_ranges``, the query value ``[low, high]`` must lie within the configured range.
        A config that lacks that range, or does not hold a ``[low, high]`` sequence there, does not match.
        For all other paths the configured value must equal the query value.
        """
        for q_key, q_value in query.items():
            keys = q_key.split('.')
            value = self.deep_get(config_data, keys)
            if 'param_ranges' in keys:
                # Previously a missing range raised TypeError (None[0]) and aborted the whole search.
                if not isinstance(value, (list, tuple)) or len(value) < 2:
                    return False
                if q_value[0] < value[0] or q_value[1] > value[1]:
                    return False
            elif value != q_value:
                return False

        return True

    def deep_get(self, d, keys):
        """Follow ``keys`` through nested dicts starting at ``d``.

        Returns the value at the end of the path. Returns None if a key is missing, or if a
        non-dict value is reached before the path is exhausted.
        """
        for key in keys:
            if not isinstance(d, dict):
                # Path continues through a scalar/list: the requested key does not exist.
                return None
            d = d.get(key)

        return d
|
reflectorch/runs/utils.py
CHANGED
|
@@ -177,7 +177,7 @@ def get_trainer_from_config(config: dict, folder_paths: dict = None):
|
|
|
177
177
|
return trainer
|
|
178
178
|
|
|
179
179
|
|
|
180
|
-
def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True, inference_device =
|
|
180
|
+
def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True, inference_device: str = 'cuda'):
|
|
181
181
|
"""Initializes a trainer object based on a configuration file (i.e. the model name) and optionally loads \
|
|
182
182
|
saved weights into the network
|
|
183
183
|
|
|
@@ -186,7 +186,7 @@ def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weig
|
|
|
186
186
|
config_dir (str): path of the configuration directory
|
|
187
187
|
model_path (str, optional): path to the network weights.
|
|
188
188
|
load_weights (bool, optional): if True the saved network weights are loaded into the network. Defaults to True.
|
|
189
|
-
|
|
189
|
+
inference_device (str, optional): overwrites the device in the configuration file for the purpose of inference on a different device then the training was performed on. Defaults to 'cuda'.
|
|
190
190
|
|
|
191
191
|
Returns:
|
|
192
192
|
Trainer: the trainer object
|
|
@@ -195,10 +195,9 @@ def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weig
|
|
|
195
195
|
config['model']['network']['pretrained_name'] = None
|
|
196
196
|
config['training']['logger']['use_neptune'] = False
|
|
197
197
|
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
config['dset']['q_generator']['kwargs']['device'] = inference_device
|
|
198
|
+
config['model']['network']['device'] = inference_device
|
|
199
|
+
config['dset']['prior_sampler']['kwargs']['device'] = inference_device
|
|
200
|
+
config['dset']['q_generator']['kwargs']['device'] = inference_device
|
|
202
201
|
|
|
203
202
|
trainer = get_trainer_from_config(config)
|
|
204
203
|
|
|
@@ -216,7 +215,7 @@ def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weig
|
|
|
216
215
|
model_path = SAVED_MODELS_DIR / model_name
|
|
217
216
|
|
|
218
217
|
try:
|
|
219
|
-
state_dict = torch.load(model_path)
|
|
218
|
+
state_dict = torch.load(model_path, map_location=inference_device)
|
|
220
219
|
except Exception as err:
|
|
221
220
|
raise RuntimeError(f'Could not load model from {model_path}') from err
|
|
222
221
|
|
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: reflectorch
|
|
3
|
-
Version: 1.0
|
|
3
|
+
Version: 1.1.0
|
|
4
4
|
Summary: A Pytorch-based package for the analysis of reflectometry data
|
|
5
5
|
Author-email: Vladimir Starostin <vladimir.starostin@uni-tuebingen.de>, Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>
|
|
6
|
-
Maintainer-email: Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>, Alexander Hinderhofer <alexander.hinderhofer@uni-tuebingen.de>
|
|
6
|
+
Maintainer-email: Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>, Vladimir Starostin <vladimir.starostin@uni-tuebingen.de>, Alexander Hinderhofer <alexander.hinderhofer@uni-tuebingen.de>
|
|
7
7
|
Project-URL: Source, https://github.com/schreiber-lab/reflectorch/
|
|
8
8
|
Project-URL: Issues, https://github.com/schreiber-lab/reflectorch/issues
|
|
9
9
|
Project-URL: Documentation, https://schreiber-lab.github.io/reflectorch/
|
|
@@ -26,7 +26,6 @@ Requires-Dist: PyYAML
|
|
|
26
26
|
Requires-Dist: click
|
|
27
27
|
Requires-Dist: matplotlib
|
|
28
28
|
Requires-Dist: ipywidgets
|
|
29
|
-
Requires-Dist: torchinfo
|
|
30
29
|
Requires-Dist: huggingface-hub
|
|
31
30
|
Provides-Extra: build
|
|
32
31
|
Requires-Dist: build ; extra == 'build'
|
|
@@ -47,12 +46,12 @@ Requires-Dist: pytest-cov ; extra == 'tests'
|
|
|
47
46
|
[](https://yaml.org/)
|
|
48
47
|
[](https://huggingface.co/valentinsingularity/reflectivity)
|
|
49
48
|
|
|
50
|
-
[](https://www.gnu.org/licenses/gpl-3.0)
|
|
51
49
|
[](https://www.python.org/)
|
|
52
50
|

|
|
53
51
|

|
|
54
52
|
[](https://www.codefactor.io/repository/github/schreiber-lab/reflectorch)
|
|
55
|
-
[](https://
|
|
53
|
+
[](https://jupyterbook.org/)
|
|
54
|
+
[](https://schreiber-lab.github.io/reflectorch/)
|
|
56
55
|
<!-- [](https://github.com/astral-sh/ruff) -->
|
|
57
56
|
|
|
58
57
|
|
|
@@ -92,16 +91,21 @@ Users with Nvidia **GPU**s need to additionally install **Pytorch with CUDA supp
|
|
|
92
91
|
|
|
93
92
|
## Get started
|
|
94
93
|
|
|
95
|
-

|
|
94
|
+
[](https://schreiber-lab.github.io/reflectorch/)
|
|
95
|
+
The full documentation of the package, containing tutorials and the API reference, was built with [Jupyter Book](https://jupyterbook.org/) and [Sphinx](https://www.sphinx-doc.org) and it is hosted at the address: [https://schreiber-lab.github.io/reflectorch/](https://schreiber-lab.github.io/reflectorch/).
|
|
96
96
|
|
|
97
|
-

|
|
97
|
+
[](https://colab.research.google.com/drive/1rf_M8S_5kYvUoK0-9-AYal_fO3oFl7ck?usp=sharing)
|
|
98
|
+
We provide an interactive Google Colab notebook for exploring the basic functionality of the package: [](https://colab.research.google.com/drive/1rf_M8S_5kYvUoK0-9-AYal_fO3oFl7ck?usp=sharing)<br>
|
|
98
99
|
|
|
99
|
-

|
|
100
|
+
[](https://huggingface.co/valentinsingularity/reflectivity)
|
|
101
|
+
Configuration files and the corresponding pretrained model weights are hosted on Huggingface: [https://huggingface.co/valentinsingularity/reflectivity](https://huggingface.co/valentinsingularity/reflectivity).
|
|
102
|
+
|
|
103
|
+
[](https://hub.docker.com/)
|
|
104
|
+
Docker images for reflectorch *will* be hosted on Dockerhub.
|
|
100
105
|
|
|
101
|
-
 Docker images for reflectorch *will* be hosted on Dockerhub.
|
|
102
106
|
|
|
103
107
|
## Citation
|
|
104
|
-
If you find our work useful in your research, please cite:
|
|
108
|
+
If you find our work useful in your research, please cite as follows:
|
|
105
109
|
```
|
|
106
110
|
@Article{Munteanu2024,
|
|
107
111
|
author = {Munteanu, Valentin and Starostin, Vladimir and Greco, Alessandro and Pithan, Linus and Gerlach, Alexander and Hinderhofer, Alexander and Kowarik, Stefan and Schreiber, Frank},
|
|
@@ -38,10 +38,11 @@ reflectorch/extensions/jupyter/__init__.py,sha256=inEXUpeVWeAhkW5nkW_dASBzsAlv4h
|
|
|
38
38
|
reflectorch/extensions/jupyter/callbacks.py,sha256=piDR4ax6JFSOPyqfkk-nxrhyWYdMrxgC8ocoaJbbbu8,1233
|
|
39
39
|
reflectorch/extensions/matplotlib/__init__.py,sha256=8II5pU8015VrMjFI8szCKBP1zjz0dFAzBn7smNQzGuA,263
|
|
40
40
|
reflectorch/extensions/matplotlib/losses.py,sha256=TqcyrFrls1N6RXotFyXDF64Xz6nJGg7n5XMSXFdeRtQ,845
|
|
41
|
-
reflectorch/inference/__init__.py,sha256=
|
|
41
|
+
reflectorch/inference/__init__.py,sha256=i0KNn83XN33mLrV7bpHdLd0SXxuGCKfbQcIoa247Uts,834
|
|
42
42
|
reflectorch/inference/inference_model.py,sha256=QvnQDRRcZByHDJh84T5W8O3X_aJLZI6AmlskE6BaBlU,36265
|
|
43
43
|
reflectorch/inference/multilayer_fitter.py,sha256=0CxDpLOEp1terR4N39yFlxhvA8qAbHf_01NbmvYadck,5510
|
|
44
44
|
reflectorch/inference/multilayer_inference_model.py,sha256=hH_-dJGdMOox8GHXdM_nODXDlNgh_v449xW5FmklRdo,7575
|
|
45
|
+
reflectorch/inference/query_matcher.py,sha256=Dk49dW0XreeCjufzYBTKchfTdVbG6759ryV6I-wQL60,3387
|
|
45
46
|
reflectorch/inference/record_time.py,sha256=3er-aoR8Sd_Kc4qNwUmRqkEz4FYhVxdi1ARnBohybzM,1140
|
|
46
47
|
reflectorch/inference/sampler_solution.py,sha256=DeJM3EXEb6S5EiASj3mmdNI-Y07Cr5UzzA5oq-vEB-Q,2288
|
|
47
48
|
reflectorch/inference/scipy_fitter.py,sha256=339M33OdmfgOpifJGLYk4KVcnnNJrY6_aH7Lz6Vtt24,5404
|
|
@@ -75,9 +76,9 @@ reflectorch/runs/__init__.py,sha256=2BcdMJul5yd726p8w4iqlKhygAAxiu1zu0MKDe96bWk,
|
|
|
75
76
|
reflectorch/runs/config.py,sha256=6aEub3NV0jmoREdegV7S3Nz-5o1xPZnmPpNgYfMpdys,963
|
|
76
77
|
reflectorch/runs/slurm_utils.py,sha256=T5vsWrcduq_N9mS9XAXjAbx7PHcYiiiwjdS0iiXh_TI,2759
|
|
77
78
|
reflectorch/runs/train.py,sha256=NaHMUYApjOCeajyS5UMQkeCVyxVtroohXK5ceHNLOkM,2719
|
|
78
|
-
reflectorch/runs/utils.py,sha256=
|
|
79
|
-
reflectorch-1.0.
|
|
80
|
-
reflectorch-1.0.
|
|
81
|
-
reflectorch-1.0.
|
|
82
|
-
reflectorch-1.0.
|
|
83
|
-
reflectorch-1.0.
|
|
79
|
+
reflectorch/runs/utils.py,sha256=j_gJYrw4fIZvKJWXPdt1mOR0d_Ht6pg0rDjE2iOTLc8,9737
|
|
80
|
+
reflectorch-1.1.0.dist-info/LICENSE.txt,sha256=2kX9kLKiIRiQRqUXwk3J-Ba3fqmztNu8ORskLBlAuKM,1098
|
|
81
|
+
reflectorch-1.1.0.dist-info/METADATA,sha256=Jj8WKCgTrNn8_TT7GZS6gbLX2BxLauF3J4FJbo-18ZM,7777
|
|
82
|
+
reflectorch-1.1.0.dist-info/WHEEL,sha256=rWxmBtp7hEUqVLOnTaDOPpR-cZpCDkzhhcBce-Zyd5k,91
|
|
83
|
+
reflectorch-1.1.0.dist-info/top_level.txt,sha256=2EyIWrt4SeZ3hNadLXvEVpPFhyoZ4An7YflP4y_E3Fc,12
|
|
84
|
+
reflectorch-1.1.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|