reflectorch 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

@@ -1,4 +1,5 @@
1
1
  from reflectorch.inference.inference_model import InferenceModel, EasyInferenceModel
2
+ from reflectorch.inference.query_matcher import HuggingfaceQueryMatcher
2
3
  from reflectorch.inference.multilayer_inference_model import MultilayerInferenceModel
3
4
  from reflectorch.inference.preprocess_exp import (
4
5
  StandardPreprocessing,
@@ -13,6 +14,7 @@ __all__ = [
13
14
  "InferenceModel",
14
15
  "EasyInferenceModel",
15
16
  "MultilayerInferenceModel",
17
+ "HuggingfaceQueryMatcher",
16
18
  "StandardPreprocessing",
17
19
  "standard_preprocessing",
18
20
  "ReflGradientFit",
@@ -0,0 +1,82 @@
1
+ import os
2
+ import tempfile
3
+ import yaml
4
+ from huggingface_hub import hf_hub_download, list_repo_files
5
+
6
class HuggingfaceQueryMatcher:
    """Downloads the available configuration files to a temporary directory and provides functionality for filtering those configuration files matching user specified queries.

    On construction, every ``.yaml`` file located under ``config_dir`` in the
    Hugging Face repository is downloaded into a freshly created temporary
    directory and parsed with ``yaml.safe_load``. The parsed configurations are
    stored in ``self.cache['parsed_configs']`` (keyed by file base name) and the
    temporary directory path in ``self.cache['temp_dir']``.

    Args:
        repo_id (str): The Hugging Face repository ID.
        config_dir (str): Directory within the repo where YAML files are stored.
    """

    def __init__(self, repo_id='valentinsingularity/reflectivity', config_dir='configs'):
        self.repo_id = repo_id
        self.config_dir = config_dir
        self.cache = {
            'parsed_configs': None,
            'temp_dir': None,
        }
        self._renew_cache()

    def _select_config_files(self, repo_files):
        """Return the repository paths that are ``.yaml`` files located under ``self.config_dir``.

        The directory is matched as a whole path component, so ``'configs'``
        selects ``'configs/a.yaml'`` but not ``'configs_old/a.yaml'``. An empty
        ``config_dir`` selects ``.yaml`` files anywhere in the repository.

        Args:
            repo_files (iterable of str): Repository-relative file paths.

        Returns:
            list: The selected paths, in their original order.
        """
        prefix = self.config_dir.strip('/')
        # Require a trailing separator so sibling directories sharing the prefix are not matched.
        prefix = f'{prefix}/' if prefix else ''
        return [path for path in repo_files if path.startswith(prefix) and path.endswith('.yaml')]

    def _renew_cache(self):
        """Download and parse all configuration files into a new temporary directory.

        Updates ``self.cache['parsed_configs']`` and ``self.cache['temp_dir']``.
        A previously created temporary directory is not removed.
        """
        temp_dir = tempfile.mkdtemp()
        print(f"Temporary directory created at: {temp_dir}")

        repo_files = list_repo_files(self.repo_id, repo_type='model')
        config_files = self._select_config_files(repo_files)

        downloaded_files = [
            hf_hub_download(repo_id=self.repo_id, filename=repo_path, local_dir=temp_dir, repo_type='model')
            for repo_path in config_files
        ]

        parsed_configs = {}
        for file_path in downloaded_files:
            with open(file_path, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
            # NOTE: keyed by base name; equally named files in different subdirectories overwrite each other.
            parsed_configs[os.path.basename(file_path)] = config_data

        self.cache['parsed_configs'] = parsed_configs
        self.cache['temp_dir'] = temp_dir

    def get_matching_configs(self, query):
        """Retrieve the configuration files that match the user specified query.

        Args:
            query (dict): Dictionary of key-value pairs to filter configurations, e.g. ``query = {'dset.prior_sampler.kwargs.max_num_layers': 3, 'dset.prior_sampler.kwargs.param_ranges.slds': [0., 100.]}``.
                For keys containing the ``param_ranges`` subkey a configuration is selected if the value of the query (i.e. desired parameter range)
                is a subrange of the parameter range in the configuration, in all other cases the values must match exactly.

        Returns:
            list: List of file names that match the query, in cache order.
        """
        return [
            file_name
            for file_name, config_data in self.cache['parsed_configs'].items()
            if self.matches_query(config_data, query)
        ]

    def matches_query(self, config_data, query):
        """Check whether a single parsed configuration satisfies every entry of ``query``.

        Args:
            config_data: The parsed configuration (``yaml.safe_load`` output; ``None`` for an empty file).
            query (dict): Mapping of dot-separated key paths to desired values.

        Returns:
            bool: ``True`` if all query entries match (an empty query always matches).
        """
        for q_key, q_value in query.items():
            keys = q_key.split('.')
            value = self.deep_get(config_data, keys)
            if 'param_ranges' in keys:
                # A configuration without this range cannot contain the requested range
                # (indexing ``None`` here used to raise a TypeError).
                if value is None:
                    return False
                # Query range must lie inside the configured range [value[0], value[1]].
                if q_value[0] < value[0] or q_value[1] > value[1]:
                    return False
            elif value != q_value:
                return False

        return True

    def deep_get(self, d, keys):
        """Follow ``keys`` through nested dictionaries.

        Args:
            d: The (possibly nested) dictionary to look into.
            keys (list of str): The key path, outermost key first.

        Returns:
            The value at the end of the path, or ``None`` if a key is missing or an
            intermediate value is not a dictionary (previously the last reachable
            non-dict value was returned, causing false matches).
        """
        for key in keys:
            if not isinstance(d, dict):
                return None
            d = d.get(key, None)

        return d
reflectorch/runs/utils.py CHANGED
@@ -177,7 +177,7 @@ def get_trainer_from_config(config: dict, folder_paths: dict = None):
177
177
  return trainer
178
178
 
179
179
 
180
- def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True, inference_device = None):
180
+ def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weights: bool = True, inference_device: str = 'cuda'):
181
181
  """Initializes a trainer object based on a configuration file (i.e. the model name) and optionally loads \
182
182
  saved weights into the network
183
183
 
@@ -186,7 +186,7 @@ def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weig
186
186
  config_dir (str): path of the configuration directory
187
187
  model_path (str, optional): path to the network weights.
188
188
  load_weights (bool, optional): if True the saved network weights are loaded into the network. Defaults to True.
189
-
189
+ inference_device (str, optional): overwrites the device in the configuration file for the purpose of inference on a different device then the training was performed on. Defaults to 'cuda'.
190
190
 
191
191
  Returns:
192
192
  Trainer: the trainer object
@@ -195,10 +195,9 @@ def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weig
195
195
  config['model']['network']['pretrained_name'] = None
196
196
  config['training']['logger']['use_neptune'] = False
197
197
 
198
- if inference_device:
199
- config['model']['network']['device'] = inference_device
200
- config['dset']['prior_sampler']['kwargs']['device'] = inference_device
201
- config['dset']['q_generator']['kwargs']['device'] = inference_device
198
+ config['model']['network']['device'] = inference_device
199
+ config['dset']['prior_sampler']['kwargs']['device'] = inference_device
200
+ config['dset']['q_generator']['kwargs']['device'] = inference_device
202
201
 
203
202
  trainer = get_trainer_from_config(config)
204
203
 
@@ -216,7 +215,7 @@ def get_trainer_by_name(config_name, config_dir=None, model_path=None, load_weig
216
215
  model_path = SAVED_MODELS_DIR / model_name
217
216
 
218
217
  try:
219
- state_dict = torch.load(model_path)
218
+ state_dict = torch.load(model_path, map_location=inference_device)
220
219
  except Exception as err:
221
220
  raise RuntimeError(f'Could not load model from {model_path}') from err
222
221
 
@@ -1,9 +1,9 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: reflectorch
3
- Version: 1.0.1
3
+ Version: 1.1.0
4
4
  Summary: A Pytorch-based package for the analysis of reflectometry data
5
5
  Author-email: Vladimir Starostin <vladimir.starostin@uni-tuebingen.de>, Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>
6
- Maintainer-email: Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>, Alexander Hinderhofer <alexander.hinderhofer@uni-tuebingen.de>
6
+ Maintainer-email: Valentin Munteanu <valentin.munteanu@uni-tuebingen.de>, Vladimir Starostin <vladimir.starostin@uni-tuebingen.de>, Alexander Hinderhofer <alexander.hinderhofer@uni-tuebingen.de>
7
7
  Project-URL: Source, https://github.com/schreiber-lab/reflectorch/
8
8
  Project-URL: Issues, https://github.com/schreiber-lab/reflectorch/issues
9
9
  Project-URL: Documentation, https://schreiber-lab.github.io/reflectorch/
@@ -26,7 +26,6 @@ Requires-Dist: PyYAML
26
26
  Requires-Dist: click
27
27
  Requires-Dist: matplotlib
28
28
  Requires-Dist: ipywidgets
29
- Requires-Dist: torchinfo
30
29
  Requires-Dist: huggingface-hub
31
30
  Provides-Extra: build
32
31
  Requires-Dist: build ; extra == 'build'
@@ -47,12 +46,12 @@ Requires-Dist: pytest-cov ; extra == 'tests'
47
46
  [![YAML](https://img.shields.io/badge/yaml-%23ffffff.svg?style=for-the-badge&logo=yaml&logoColor=151515)](https://yaml.org/)
48
47
  [![Hugging Face](https://img.shields.io/badge/Hugging%20Face-%23FFD700.svg?style=for-the-badge&logo=huggingface&logoColor=black)](https://huggingface.co/valentinsingularity/reflectivity)
49
48
 
50
- [![License: GPLv3](https://img.shields.io/badge/License-GPLv3-blue.svg)](https://www.gnu.org/licenses/gpl-3.0)
51
49
  [![Python version](https://img.shields.io/badge/python-3.7%7C3.8%7C3.9%7C3.10%7C3.11%7C3.12-blue.svg)](https://www.python.org/)
52
50
  ![CI workflow status](https://github.com/schreiber-lab/reflectorch/actions/workflows/ci.yml/badge.svg)
53
51
  ![Repos size](https://img.shields.io/github/repo-size/schreiber-lab/reflectorch)
54
52
  [![CodeFactor](https://www.codefactor.io/repository/github/schreiber-lab/reflectorch/badge)](https://www.codefactor.io/repository/github/schreiber-lab/reflectorch)
55
- [![Jupyter Book Documentation](https://jupyterbook.org/badge.svg)](https://schreiber-lab.github.io/reflectorch/)
53
+ [![Jupyter Book Documentation](https://jupyterbook.org/badge.svg)](https://jupyterbook.org/)
54
+ [![Documentation Page](https://img.shields.io/badge/Documentation%20Page-%23FFDD33.svg?style=flat&logo=read-the-docs&logoColor=black)](https://schreiber-lab.github.io/reflectorch/)
56
55
  <!-- [![Code style: Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) -->
57
56
 
58
57
 
@@ -92,16 +91,21 @@ Users with Nvidia **GPU**s need to additionally install **Pytorch with CUDA supp
92
91
 
93
92
  ## Get started
94
93
 
95
- ![](https://img.shields.io/badge/Documentation%20Page-%23FFDD33.svg?style=flat&logo=read-the-docs&logoColor=black) The full documentation of the package, containing tutorials and the API reference, was built with [Jupyter Book](https://jupyterbook.org/) and [Sphinx](https://www.sphinx-doc.org) and it is hosted at the address: [https://schreiber-lab.github.io/reflectorch/](https://schreiber-lab.github.io/reflectorch/).
94
+ [![Documentation Page](https://img.shields.io/badge/Documentation%20Page-%23FFDD33.svg?style=flat&logo=read-the-docs&logoColor=black)](https://schreiber-lab.github.io/reflectorch/)
95
+ The full documentation of the package, containing tutorials and the API reference, was built with [Jupyter Book](https://jupyterbook.org/) and [Sphinx](https://www.sphinx-doc.org) and it is hosted at the address: [https://schreiber-lab.github.io/reflectorch/](https://schreiber-lab.github.io/reflectorch/).
96
96
 
97
- ![](https://img.shields.io/badge/Interactive%20Notebook-%23F9AB00.svg?style=flat&logo=google-colab&logoColor=black) We provide an interactive Google Colab notebook for exploring the basic functionality of the package: [![Explore reflectorch in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1rf_M8S_5kYvUoK0-9-AYal_fO3oFl7ck?usp=sharing)<br>
97
+ [![Interactive Notebook](https://img.shields.io/badge/Interactive%20Notebook-%23F9AB00.svg?style=flat&logo=google-colab&logoColor=black)](https://colab.research.google.com/drive/1rf_M8S_5kYvUoK0-9-AYal_fO3oFl7ck?usp=sharing)
98
+ We provide an interactive Google Colab notebook for exploring the basic functionality of the package: [![Explore reflectorch in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1rf_M8S_5kYvUoK0-9-AYal_fO3oFl7ck?usp=sharing)<br>
98
99
 
99
- ![](https://img.shields.io/badge/Hugging%20Face-%23FFD700.svg?style=flat&logo=huggingface&logoColor=black) Configuration files and the corresponding pretrained model weights are hosted on Huggingface: [https://huggingface.co/valentinsingularity/reflectivity](https://huggingface.co/valentinsingularity/reflectivity).
100
+ [![Hugging Face](https://img.shields.io/badge/Hugging%20Face-%23FFD700.svg?style=flat&logo=huggingface&logoColor=black)](https://huggingface.co/valentinsingularity/reflectivity)
101
+ Configuration files and the corresponding pretrained model weights are hosted on Hugging Face: [https://huggingface.co/valentinsingularity/reflectivity](https://huggingface.co/valentinsingularity/reflectivity).
102
+
103
+ [![Docker](https://img.shields.io/badge/Docker-2496ED.svg?style=flat&logo=docker&logoColor=white)](https://hub.docker.com/)
104
+ Docker images for reflectorch *will* be hosted on Docker Hub.
100
105
 
101
- ![](https://img.shields.io/badge/Docker-2496ED.svg?style=flat&logo=docker&logoColor=white) Docker images for reflectorch *will* be hosted on Dockerhub.
102
106
 
103
107
  ## Citation
104
- If you find our work useful in your research, please cite:
108
+ If you find our work useful in your research, please cite as follows:
105
109
  ```
106
110
  @Article{Munteanu2024,
107
111
  author = {Munteanu, Valentin and Starostin, Vladimir and Greco, Alessandro and Pithan, Linus and Gerlach, Alexander and Hinderhofer, Alexander and Kowarik, Stefan and Schreiber, Frank},
@@ -38,10 +38,11 @@ reflectorch/extensions/jupyter/__init__.py,sha256=inEXUpeVWeAhkW5nkW_dASBzsAlv4h
38
38
  reflectorch/extensions/jupyter/callbacks.py,sha256=piDR4ax6JFSOPyqfkk-nxrhyWYdMrxgC8ocoaJbbbu8,1233
39
39
  reflectorch/extensions/matplotlib/__init__.py,sha256=8II5pU8015VrMjFI8szCKBP1zjz0dFAzBn7smNQzGuA,263
40
40
  reflectorch/extensions/matplotlib/losses.py,sha256=TqcyrFrls1N6RXotFyXDF64Xz6nJGg7n5XMSXFdeRtQ,845
41
- reflectorch/inference/__init__.py,sha256=8KDIazaanBhAdtGg9AXc29SBfO5rO4Q0_BeBy0H6G54,729
41
+ reflectorch/inference/__init__.py,sha256=i0KNn83XN33mLrV7bpHdLd0SXxuGCKfbQcIoa247Uts,834
42
42
  reflectorch/inference/inference_model.py,sha256=QvnQDRRcZByHDJh84T5W8O3X_aJLZI6AmlskE6BaBlU,36265
43
43
  reflectorch/inference/multilayer_fitter.py,sha256=0CxDpLOEp1terR4N39yFlxhvA8qAbHf_01NbmvYadck,5510
44
44
  reflectorch/inference/multilayer_inference_model.py,sha256=hH_-dJGdMOox8GHXdM_nODXDlNgh_v449xW5FmklRdo,7575
45
+ reflectorch/inference/query_matcher.py,sha256=Dk49dW0XreeCjufzYBTKchfTdVbG6759ryV6I-wQL60,3387
45
46
  reflectorch/inference/record_time.py,sha256=3er-aoR8Sd_Kc4qNwUmRqkEz4FYhVxdi1ARnBohybzM,1140
46
47
  reflectorch/inference/sampler_solution.py,sha256=DeJM3EXEb6S5EiASj3mmdNI-Y07Cr5UzzA5oq-vEB-Q,2288
47
48
  reflectorch/inference/scipy_fitter.py,sha256=339M33OdmfgOpifJGLYk4KVcnnNJrY6_aH7Lz6Vtt24,5404
@@ -75,9 +76,9 @@ reflectorch/runs/__init__.py,sha256=2BcdMJul5yd726p8w4iqlKhygAAxiu1zu0MKDe96bWk,
75
76
  reflectorch/runs/config.py,sha256=6aEub3NV0jmoREdegV7S3Nz-5o1xPZnmPpNgYfMpdys,963
76
77
  reflectorch/runs/slurm_utils.py,sha256=T5vsWrcduq_N9mS9XAXjAbx7PHcYiiiwjdS0iiXh_TI,2759
77
78
  reflectorch/runs/train.py,sha256=NaHMUYApjOCeajyS5UMQkeCVyxVtroohXK5ceHNLOkM,2719
78
- reflectorch/runs/utils.py,sha256=8hFWDmPTvfIrrk9v-nVCVyV3-_lzm0HvV_qWtjtAlBQ,9541
79
- reflectorch-1.0.1.dist-info/LICENSE.txt,sha256=2kX9kLKiIRiQRqUXwk3J-Ba3fqmztNu8ORskLBlAuKM,1098
80
- reflectorch-1.0.1.dist-info/METADATA,sha256=DFSHHEXKXrsWjGj2xthK49n43C_0AAjG9VM2dsrDSbY,7398
81
- reflectorch-1.0.1.dist-info/WHEEL,sha256=FZ75kcLy9M91ncbIgG8dnpCncbiKXSRGJ_PFILs6SFg,91
82
- reflectorch-1.0.1.dist-info/top_level.txt,sha256=2EyIWrt4SeZ3hNadLXvEVpPFhyoZ4An7YflP4y_E3Fc,12
83
- reflectorch-1.0.1.dist-info/RECORD,,
79
+ reflectorch/runs/utils.py,sha256=j_gJYrw4fIZvKJWXPdt1mOR0d_Ht6pg0rDjE2iOTLc8,9737
80
+ reflectorch-1.1.0.dist-info/LICENSE.txt,sha256=2kX9kLKiIRiQRqUXwk3J-Ba3fqmztNu8ORskLBlAuKM,1098
81
+ reflectorch-1.1.0.dist-info/METADATA,sha256=Jj8WKCgTrNn8_TT7GZS6gbLX2BxLauF3J4FJbo-18ZM,7777
82
+ reflectorch-1.1.0.dist-info/WHEEL,sha256=rWxmBtp7hEUqVLOnTaDOPpR-cZpCDkzhhcBce-Zyd5k,91
83
+ reflectorch-1.1.0.dist-info/top_level.txt,sha256=2EyIWrt4SeZ3hNadLXvEVpPFhyoZ4An7YflP4y_E3Fc,12
84
+ reflectorch-1.1.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (71.0.1)
2
+ Generator: setuptools (71.0.4)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5