reflectorch 1.0.1__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. Click here for more details.

Files changed (59) hide show
  1. reflectorch/__init__.py +0 -6
  2. reflectorch/data_generation/__init__.py +0 -6
  3. reflectorch/data_generation/dataset.py +0 -6
  4. reflectorch/data_generation/likelihoods.py +0 -6
  5. reflectorch/data_generation/noise.py +0 -6
  6. reflectorch/data_generation/priors/__init__.py +0 -6
  7. reflectorch/data_generation/priors/base.py +0 -6
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +0 -6
  9. reflectorch/data_generation/priors/independent_priors.py +0 -6
  10. reflectorch/data_generation/priors/multilayer_structures.py +0 -6
  11. reflectorch/data_generation/priors/no_constraints.py +0 -6
  12. reflectorch/data_generation/priors/parametric_subpriors.py +0 -6
  13. reflectorch/data_generation/priors/params.py +0 -6
  14. reflectorch/data_generation/priors/subprior_sampler.py +0 -6
  15. reflectorch/data_generation/priors/utils.py +0 -6
  16. reflectorch/data_generation/process_data.py +0 -6
  17. reflectorch/data_generation/q_generator.py +0 -6
  18. reflectorch/data_generation/reflectivity/__init__.py +13 -9
  19. reflectorch/data_generation/reflectivity/abeles.py +6 -5
  20. reflectorch/data_generation/reflectivity/kinematical.py +14 -0
  21. reflectorch/data_generation/reflectivity/memory_eff.py +13 -0
  22. reflectorch/data_generation/reflectivity/smearing.py +2 -2
  23. reflectorch/data_generation/scale_curves.py +0 -6
  24. reflectorch/data_generation/smearing.py +3 -2
  25. reflectorch/data_generation/utils.py +0 -6
  26. reflectorch/extensions/__init__.py +0 -6
  27. reflectorch/extensions/jupyter/__init__.py +0 -6
  28. reflectorch/extensions/jupyter/callbacks.py +0 -6
  29. reflectorch/extensions/matplotlib/__init__.py +0 -6
  30. reflectorch/extensions/matplotlib/losses.py +0 -6
  31. reflectorch/inference/__init__.py +2 -0
  32. reflectorch/inference/inference_model.py +9 -6
  33. reflectorch/inference/query_matcher.py +82 -0
  34. reflectorch/ml/__init__.py +0 -7
  35. reflectorch/ml/basic_trainer.py +0 -6
  36. reflectorch/ml/callbacks.py +0 -6
  37. reflectorch/ml/loggers.py +0 -7
  38. reflectorch/ml/schedulers.py +0 -6
  39. reflectorch/ml/trainers.py +1 -34
  40. reflectorch/ml/utils.py +0 -7
  41. reflectorch/models/__init__.py +0 -7
  42. reflectorch/models/encoders/__init__.py +0 -8
  43. reflectorch/models/encoders/conv_encoder.py +0 -6
  44. reflectorch/models/encoders/conv_res_net.py +1 -5
  45. reflectorch/models/networks/__init__.py +0 -6
  46. reflectorch/paths.py +0 -6
  47. reflectorch/runs/__init__.py +2 -6
  48. reflectorch/runs/config.py +0 -6
  49. reflectorch/runs/slurm_utils.py +0 -6
  50. reflectorch/runs/train.py +0 -6
  51. reflectorch/runs/utils.py +82 -14
  52. reflectorch/utils.py +0 -6
  53. {reflectorch-1.0.1.dist-info → reflectorch-1.2.0.dist-info}/METADATA +15 -10
  54. reflectorch-1.2.0.dist-info/RECORD +83 -0
  55. {reflectorch-1.0.1.dist-info → reflectorch-1.2.0.dist-info}/WHEEL +1 -1
  56. reflectorch/models/encoders/transformers.py +0 -56
  57. reflectorch-1.0.1.dist-info/RECORD +0 -83
  58. {reflectorch-1.0.1.dist-info → reflectorch-1.2.0.dist-info}/LICENSE.txt +0 -0
  59. {reflectorch-1.0.1.dist-info → reflectorch-1.2.0.dist-info}/top_level.txt +0 -0
reflectorch/__init__.py CHANGED
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.data_generation import *
8
2
  from reflectorch.ml import *
9
3
  from reflectorch.models import *
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.data_generation.dataset import BasicDataset, BATCH_DATA_TYPE
8
2
  from reflectorch.data_generation.priors import (
9
3
  Params,
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Dict, Union
8
2
  import warnings
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Union, Tuple
8
2
 
9
3
  import torch
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Union, Tuple
8
2
  from math import log10
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.data_generation.priors.params import Params
8
2
  from reflectorch.data_generation.priors.base import PriorSampler
9
3
  from reflectorch.data_generation.priors.no_constraints import BasicPriorSampler
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from torch import Tensor
8
2
 
9
3
  from reflectorch.data_generation.priors.params import Params
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple, Union, List
8
2
 
9
3
  import torch
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Union, Tuple
8
2
  from math import log
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple, Dict
8
2
 
9
3
  import numpy as np
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import logging
8
2
  from functools import lru_cache
9
3
  from typing import Tuple
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple, Dict, Type, List
8
2
 
9
3
  import torch
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import List, Tuple
8
2
 
9
3
  import torch
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from functools import lru_cache
8
2
  from typing import Tuple
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple
8
2
 
9
3
  import torch
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Any
8
2
 
9
3
  __all__ = [
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Tuple, Union
8
2
 
9
3
  import numpy as np
@@ -24,21 +24,25 @@ def reflectivity(
24
24
  log: bool = False,
25
25
  abeles_func=None,
26
26
  ):
27
- """Function which computes the reflectivity from thin film parameters using the Abeles matrix formalism
27
+ """Function which computes the reflectivity curves from thin film parameters.
28
+ By default it uses the fast implementation of the Abeles matrix formalism.
28
29
 
29
30
  Args:
30
- q (Tensor): the momentum transfer (q) values
31
- thickness (Tensor): the layer thicknesses
32
- roughness (Tensor): the interlayer roughnesses
33
- sld (Tensor): the SLDs of the layers
34
- dq (Tensor, optional): the resolution for curve smearing. Defaults to None.
31
+ q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
32
+ thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
33
+ roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
34
+ sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
35
+ It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
36
+ dq (Tensor, optional): tensor of resolutions used for curve smearing with shape [batch_size, 1].
37
+ Either dq if ``constant_dq`` is ``True`` or dq/q if ``constant_dq`` is ``False``. Defaults to None.
35
38
  gauss_num (int, optional): the number of gaussians for curve smearing. Defaults to 51.
36
- constant_dq (bool, optional): whether the smearing is constant. Defaults to True.
39
+ constant_dq (bool, optional): if ``True`` the smearing is constant (constant dq at each point in the curve)
40
+ otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to True.
37
41
  log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to False.
38
- abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default implementation. Defaults to None.
42
+ abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default Abeles matrix implementation. Defaults to None.
39
43
 
40
44
  Returns:
41
- Tensor: the computed reflectivity curves
45
+ Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
42
46
  """
43
47
  abeles_func = abeles_func or abeles
44
48
  q = torch.atleast_2d(q)
@@ -14,13 +14,14 @@ def abeles(
14
14
  """Simulates reflectivity curves for SLD profiles with box model parameterization using the Abeles matrix method
15
15
 
16
16
  Args:
17
- q (Tensor): q values
18
- thickness (Tensor): layer thicknesses
19
- roughness (Tensor): interlayer roughnesses
20
- sld (Tensor): layer SLDs
17
+ q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
18
+ thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
19
+ roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
20
+ sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
21
+ It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
21
22
 
22
23
  Returns:
23
- Tensor: simulated reflectivity curves
24
+ Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
24
25
  """
25
26
  c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
26
27
 
@@ -13,6 +13,20 @@ def kinematical_approximation(
13
13
  apply_fresnel: bool = True,
14
14
  log: bool = False,
15
15
  ):
16
+ """Simulates reflectivity curves for SLD profiles with box model parameterization using the kinematical approximation
17
+
18
+ Args:
19
+ q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
20
+ thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
21
+ roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
22
+ sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
23
+ It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
24
+ apply_fresnel (bool, optional): whether to use the Fresnel coefficient in the computation. Defaults to ``True``.
25
+ log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to ``False``.
26
+
27
+ Returns:
28
+ Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
29
+ """
16
30
  c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
17
31
 
18
32
  batch_size, num_layers = thickness.shape
@@ -12,6 +12,19 @@ def abeles_memory_eff(
12
12
  roughness: Tensor,
13
13
  sld: Tensor,
14
14
  ):
15
+ """Simulates reflectivity curves for SLD profiles with box model parameterization using a memory-efficient implementation of the Abeles matrix method.
16
+ It is computationally slower compared to the implementation in the 'abeles' function.
17
+
18
+ Args:
19
+ q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
20
+ thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
21
+ roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
22
+ sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
23
+ It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
24
+
25
+ Returns:
26
+ Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
27
+ """
15
28
  c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
16
29
 
17
30
  batch_size, num_layers = thickness.shape
@@ -68,8 +68,8 @@ def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51
68
68
  lowq = torch.clamp_min_(q.min(1).values, 1e-6)
69
69
  highq = q.max(1).values
70
70
 
71
- start = torch.log10(lowq) - 6 * resolutions / _FWHM
72
- end = torch.log10(highq * (1 + 6 * resolutions / _FWHM))
71
+ start = torch.log10(lowq)[:, None] - 6 * resolutions / _FWHM
72
+ end = torch.log10(highq[:, None] * (1 + 6 * resolutions / _FWHM))
73
73
 
74
74
  interpnums = torch.abs(
75
75
  (torch.abs(end - start)) / (1.7 * resolutions / _FWHM / gaussgpoint)
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from pathlib import Path
8
2
 
9
3
  import torch
@@ -9,8 +9,9 @@ class Smearing(object):
9
9
  The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.
10
10
 
11
11
  Args:
12
- sigma_range (tuple, optional): the range for sampling the standard deviation of the gaussians. Defaults to (1e-4, 5e-3).
13
- constant_dq (bool, optional): whether the smearing is constant for each q point. Defaults to True.
12
+ sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (1e-4, 5e-3).
13
+ constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
14
+ otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to True.
14
15
  gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
15
16
  share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
16
17
  """
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import List, Union
8
2
  from math import sqrt, pi, log10
9
3
 
@@ -1,6 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from .callbacks import JPlotLoss
8
2
 
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from IPython.display import clear_output
8
2
 
9
3
  from ...ml import TrainerCallback, Trainer
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.extensions.matplotlib.losses import plot_losses
8
2
 
9
3
  __all__ = [
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import matplotlib.pyplot as plt
8
2
 
9
3
 
@@ -1,4 +1,5 @@
1
1
  from reflectorch.inference.inference_model import InferenceModel, EasyInferenceModel
2
+ from reflectorch.inference.query_matcher import HuggingfaceQueryMatcher
2
3
  from reflectorch.inference.multilayer_inference_model import MultilayerInferenceModel
3
4
  from reflectorch.inference.preprocess_exp import (
4
5
  StandardPreprocessing,
@@ -13,6 +14,7 @@ __all__ = [
13
14
  "InferenceModel",
14
15
  "EasyInferenceModel",
15
16
  "MultilayerInferenceModel",
17
+ "HuggingfaceQueryMatcher",
16
18
  "StandardPreprocessing",
17
19
  "standard_preprocessing",
18
20
  "ReflGradientFit",
@@ -35,15 +35,17 @@ class EasyInferenceModel(object):
35
35
  config_name (str, optional): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension). Defaults to None.
36
36
  model_name (str, optional): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name + '.pt'`. Defaults to None
37
37
  root_dir (str, optional): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR). Defaults to None.
38
+ weights_format (str, optional): format (extension) of the weights file, either 'pt' or 'safetensors'. Defaults to 'safetensors'.
38
39
  repo_id (str, optional): the id of the Huggingface repository from which the configuration files and model weights should be downloaded automatically if not found locally (in the 'configs' and 'saved_models' subdirectories of the root directory). Defaults to 'valentinsingularity/reflectivity'.
39
40
  trainer (PointEstimatorTrainer, optional): if provided, this trainer instance is used directly instead of being initialized from the configuration file. Defaults to None.
40
41
  device (str, optional): the Pytorch device ('cuda' or 'cpu'). Defaults to 'cuda'.
41
42
  """
42
- def __init__(self, config_name: str = None, model_name: str = None, root_dir:str = None, repo_id: str = 'valentinsingularity/reflectivity',
43
- trainer: PointEstimatorTrainer = None, device='cuda'):
43
+ def __init__(self, config_name: str = None, model_name: str = None, root_dir:str = None, weights_format: str = 'safetensors',
44
+ repo_id: str = 'valentinsingularity/reflectivity', trainer: PointEstimatorTrainer = None, device='cuda'):
44
45
  self.config_name = config_name
45
46
  self.model_name = model_name
46
47
  self.root_dir = root_dir
48
+ self.weights_format = weights_format
47
49
  self.repo_id = repo_id
48
50
  self.trainer = trainer
49
51
  self.device = device
@@ -58,7 +60,7 @@ class EasyInferenceModel(object):
58
60
 
59
61
  Args:
60
62
  config_name (str): the name of the configuration file used to initialize the model (either with or without the '.yaml' extension).
61
- model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' extension), only required if different than: `'model_' + config_name + '.pt'`.
63
+ model_name (str): the name of the file containing the weights of the model (either with or without the '.pt' or '.safetensors' extension), only required if different than: `'model_' + config_name + extension`.
62
64
  root_dir (str): path to root directory containing the 'configs' and 'saved_models' subdirectories, if different from the package root directory (ROOT_DIR).
63
65
  """
64
66
  if self.config_name == config_name and self.trainer is not None:
@@ -72,9 +74,10 @@ class EasyInferenceModel(object):
72
74
  self.config_name = config_name
73
75
 
74
76
  self.config_dir = Path(root_dir) / 'configs' if root_dir else CONFIG_DIR
75
- self.model_name = model_name or 'model_' + config_name_no_extension + '.pt'
76
- if not self.model_name.endswith('.pt'):
77
- self.model_name += '.pt'
77
+ weights_extension = '.' + self.weights_format
78
+ self.model_name = model_name or 'model_' + config_name_no_extension + weights_extension
79
+ if not self.model_name.endswith(weights_extension):
80
+ self.model_name += weights_extension
78
81
  self.model_dir = Path(root_dir) / 'saved_models' if root_dir else SAVED_MODELS_DIR
79
82
 
80
83
  config_path = Path(self.config_dir) / self.config_name
@@ -0,0 +1,82 @@
1
+ import os
2
+ import tempfile
3
+ import yaml
4
+ from huggingface_hub import hf_hub_download, list_repo_files
5
+
6
+ class HuggingfaceQueryMatcher:
7
+ """Downloads the available configuration files to a temporary directory and provides functionality for filtering those configuration files matching user-specified queries.
8
+
9
+ Args:
10
+ repo_id (str): The Hugging Face repository ID.
11
+ config_dir (str): Directory within the repo where YAML files are stored.
12
+ """
13
+ def __init__(self, repo_id='valentinsingularity/reflectivity', config_dir='configs'):
14
+ self.repo_id = repo_id
15
+ self.config_dir = config_dir
16
+ self.cache = {
17
+ 'parsed_configs': None,
18
+ 'temp_dir': None
19
+ }
20
+ self._renew_cache()
21
+
22
+ def _renew_cache(self):
23
+ temp_dir = tempfile.mkdtemp()
24
+ print(f"Temporary directory created at: {temp_dir}")
25
+
26
+ repo_files = list_repo_files(self.repo_id, repo_type='model')
27
+ config_files = [file for file in repo_files if file.startswith(self.config_dir) and file.endswith('.yaml')]
28
+
29
+ downloaded_files = []
30
+ for file in config_files:
31
+ file_path = hf_hub_download(repo_id=self.repo_id, filename=file, local_dir=temp_dir, repo_type='model')
32
+ downloaded_files.append(file_path)
33
+
34
+ parsed_configs = {}
35
+ for file_path in downloaded_files:
36
+ with open(file_path, 'r') as file:
37
+ config_data = yaml.safe_load(file)
38
+ file_name = os.path.basename(file_path)
39
+ parsed_configs[file_name] = config_data
40
+
41
+ self.cache['parsed_configs'] = parsed_configs
42
+ self.cache['temp_dir'] = temp_dir
43
+
44
+ def get_matching_configs(self, query):
45
+ """Retrieves configuration files that match the user-specified query.
46
+
47
+ Args:
48
+ query (dict): Dictionary of key-value pairs to filter configurations, e.g. ``query = {'dset.prior_sampler.kwargs.max_num_layers': 3, 'dset.prior_sampler.kwargs.param_ranges.slds': [0., 100.]}``.
49
+ For keys containing the ``param_ranges`` subkey a configuration is selected if the value of the query (i.e. desired parameter range)
50
+ is a subrange of the parameter range in the configuration, in all other cases the values must match exactly.
51
+
52
+ Returns:
53
+ list: List of file names that match the query.
54
+ """
55
+
56
+ filtered_configs = []
57
+
58
+ for file_name, config_data in self.cache['parsed_configs'].items():
59
+ if self.matches_query(config_data, query):
60
+ filtered_configs.append(file_name)
61
+
62
+ return filtered_configs
63
+
64
+ def matches_query(self, config_data, query):
65
+ for q_key, q_value in query.items():
66
+ keys = q_key.split('.')
67
+ value = self.deep_get(config_data, keys)
68
+ if 'param_ranges' in keys:
69
+ if q_value[0] < value[0] or q_value[1] > value[1]:
70
+ return False
71
+ else:
72
+ if value != q_value:
73
+ return False
74
+
75
+ return True
76
+
77
+ def deep_get(self, d, keys):
78
+ for key in keys:
79
+ if isinstance(d, dict):
80
+ d = d.get(key, None)
81
+
82
+ return d
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from reflectorch.ml.basic_trainer import *
8
2
  from reflectorch.ml.callbacks import *
9
3
  from reflectorch.ml.trainers import *
@@ -32,6 +26,5 @@ __all__ = [
32
26
  'MultilayerDataLoader',
33
27
  'RealTimeSimTrainer',
34
28
  'DenoisingAETrainer',
35
- 'VAETrainer',
36
29
  'PointEstimatorTrainer',
37
30
  ]
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from typing import Optional, Tuple, Iterable, Any, Union, Type
8
2
  from collections import defaultdict
9
3
 
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import torch
8
2
 
9
3
  import numpy as np
reflectorch/ml/loggers.py CHANGED
@@ -1,10 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
1
  __all__ = [
9
2
  'Logger',
10
3
  'Loggers',
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  from torch.optim import lr_scheduler
8
2
 
9
3
  import numpy as np
@@ -1,9 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
1
  import numpy as np
8
2
  import torch
9
3
  import torch.nn.functional as F
@@ -16,7 +10,6 @@ from reflectorch.ml.dataloaders import ReflectivityDataLoader
16
10
  __all__ = [
17
11
  'RealTimeSimTrainer',
18
12
  'DenoisingAETrainer',
19
- 'VAETrainer',
20
13
  'PointEstimatorTrainer',
21
14
  ]
22
15
 
@@ -97,30 +90,4 @@ class DenoisingAETrainer(RealTimeSimTrainer):
97
90
  restored_curves = self.model(scaled_noisy_curves)
98
91
  loss = self.criterion(scaled_curves, restored_curves)
99
92
  return {'loss': loss}
100
-
101
-
102
- class VAETrainer(DenoisingAETrainer):
103
- """Trainer which can be used for training a denoising autoencoder model. Overrides _get_batch and get_loss_dict methods """
104
- def init(self):
105
- self.loader.calc_denoised_curves = True
106
- self.freebits = 0.05
107
-
108
- def calc_kl(self, z_mu, z_logvar):
109
- return 0.5*(z_mu**2 + torch.exp(z_logvar) - 1 - z_logvar)
110
-
111
- def gaussian_log_prob(self, z, mu, logvar):
112
- return -0.5*(np.log(2*np.pi) + logvar + (z-mu)**2/torch.exp(logvar))
113
-
114
- def get_loss_dict(self, batch_data):
115
- """returns the reconstruction loss of the autoencoder"""
116
- scaled_noisy_curves, scaled_curves = batch_data
117
- _, (z_mu, z_logvar, restored_curves_mu, restored_curves_logvar) = self.model(scaled_noisy_curves)
118
-
119
- l_rec = -torch.mean(self.gaussian_log_prob(scaled_curves, restored_curves_mu, restored_curves_logvar), dim=-1)
120
- l_kl = torch.mean(F.relu(self.calc_kl(z_mu, z_logvar) - self.freebits*np.log(2)) + self.freebits*np.log(2), dim=-1)
121
- loss = torch.mean(l_rec + l_kl)/np.log(2)
122
-
123
- l_rec = torch.mean(l_rec)
124
- l_kl = torch.mean(l_kl)
125
-
126
- return {'loss': loss}
93
+
reflectorch/ml/utils.py CHANGED
@@ -1,9 +1,2 @@
1
- # -*- coding: utf-8 -*-
2
- #
3
- #
4
- # This source code is licensed under the GPL license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
-
8
1
  def is_divisor(num: int, div: int):
9
2
  return num and not num % div