reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,19 @@
1
+ import numpy as np
2
+
3
+
4
def interp_reflectivity(q_interp, q, reflectivity, min_value: float = 1e-10, logspace=False):
    """Interpolate a reflectivity curve on a base-10 logarithmic intensity scale.

    The reflectivity is clipped from below at ``min_value``, converted to
    ``log10``, linearly interpolated and converted back with ``10 **``.
    The measured points are sorted by ``q`` first, because :func:`numpy.interp`
    requires increasing sample points and silently returns meaningless values
    otherwise.

    Args:
        q_interp (array-like): reciprocal space points used for the interpolation
        q (array-like): reciprocal space points of the measured reflectivity curve
        reflectivity (array-like): reflectivity curve measured at the points ``q``
        min_value (float, optional): minimum intensity of the reflectivity curve. Defaults to 1e-10.
        logspace (bool, optional): if True, the q axis is also interpolated on a
            log10 scale (``q`` and ``q_interp`` must then be strictly positive).
            Defaults to False.

    Returns:
        array-like: interpolated reflectivity curve
    """
    q = np.asarray(q)
    reflectivity = np.asarray(reflectivity)

    # np.interp assumes increasing sample points; a stable sort keeps
    # already-sorted input (the usual case) exactly as it was.
    order = np.argsort(q, kind="stable")
    q = q[order]
    reflectivity = reflectivity[order]

    log_curve = np.log10(np.clip(reflectivity, min_value, None))
    if logspace:
        return 10 ** np.interp(np.log10(q_interp), np.log10(q), log_curve)
    return 10 ** np.interp(q_interp, q, log_curve)
@@ -0,0 +1,21 @@
1
+ try:
2
+ from typing import Literal
3
+ except ImportError:
4
+ from typing_extensions import Literal
5
+
6
+ import numpy as np
7
+ from numpy import ndarray
8
+
9
# Accepted values of the ``mode`` argument of ``intensity2reflectivity``.
NORMALIZE_MODE = Literal["first", "max", "incoming_intensity"]
10
+
11
+
12
def intensity2reflectivity(intensity: ndarray, mode: NORMALIZE_MODE, incoming_intensity=None) -> np.ndarray:
    """Normalize measured intensities to a reflectivity curve.

    Args:
        intensity (array-like): measured intensities; converted with
            ``np.asarray`` so that plain lists are accepted and ``[0]`` is
            always positional.
        mode (NORMALIZE_MODE): ``"first"`` divides by the first point,
            ``"max"`` by the maximum, ``"incoming_intensity"`` by
            ``incoming_intensity``.
        incoming_intensity (float or array-like, optional): divisor used by the
            ``"incoming_intensity"`` mode. Defaults to None.

    Returns:
        np.ndarray: the normalized curve.

    Raises:
        ValueError: if ``mode`` is unknown, or if ``mode`` is
            ``"incoming_intensity"`` and ``incoming_intensity`` is None.
    """
    intensity = np.asarray(intensity)
    if mode == "first":
        return intensity / intensity[0]
    if mode == "max":
        return intensity / intensity.max()
    if mode == "incoming_intensity":
        if incoming_intensity is None:
            raise ValueError("incoming_intensity is None")
        return intensity / incoming_intensity
    raise ValueError(f"Unknown mode {mode}")
@@ -0,0 +1,121 @@
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+
5
+ from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
6
+ from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction, BEAM_SHAPE
7
+ from reflectorch.inference.preprocess_exp.normalize import intensity2reflectivity, NORMALIZE_MODE
8
+ from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
9
+ from reflectorch.inference.preprocess_exp.cut_with_q_ratio import cut_curve
10
+ from reflectorch.utils import angle_to_q
11
+
12
+
13
def standard_preprocessing(
    intensity: np.ndarray,
    scattering_angle: np.ndarray,
    attenuation: np.ndarray,
    q_interp: np.ndarray,
    wavelength: float,
    beam_width: float,
    sample_length: float,
    min_intensity: float = 1e-10,
    beam_shape: BEAM_SHAPE = "gauss",
    normalize_mode: NORMALIZE_MODE = "max",
    incoming_intensity: float = None,
    max_q: float = None,  # if provided, max_angle is ignored
    max_angle: float = None,
) -> dict:
    """Preprocess a raw experimental reflectivity curve.

    Steps, in order: attenuation correction, footprint correction,
    normalization to reflectivity, removal of low-statistics points,
    conversion of angles to q, cutting at a maximum q (or angle) and
    log-scale interpolation onto ``q_interp``.

    Args:
        intensity (np.ndarray): array of intensities of the reflectivity curve
        scattering_angle (np.ndarray): array of scattering angles
        attenuation (np.ndarray): attenuation factors for each measured point
        q_interp (np.ndarray): reciprocal space points used for the interpolation
        wavelength (float): the wavelength of the beam
        beam_width (float): the beam width
        sample_length (float): the sample length
        min_intensity (float, optional): threshold applied to the *normalized*
            reflectivity (after attenuation/footprint correction). Points whose
            value is not greater than it, or is non-finite, are removed. It is
            also used as the lower clipping value of the log-scale
            interpolation. Defaults to 1e-10.
        beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box". Defaults to "gauss".
        normalize_mode (NORMALIZE_MODE, optional): normalization mode, either "first", "max" or "incoming_intensity". Defaults to "max".
        incoming_intensity (float, optional): divisor used by the "incoming_intensity" normalization mode. Defaults to None.
        max_q (float, optional): the maximum q value at which the curve is cut. Defaults to None.
        max_angle (float, optional): the maximum scattering angle at which the curve is cut; only used if max_q is not provided. Defaults to None.

    Returns:
        dict: with keys ``"curve_interp"`` (interpolated reflectivity),
        ``"curve"`` (reflectivity before interpolation), ``"q_values"``
        (q values before interpolation), ``"q_interp"`` and ``"q_ratio"``
        (as returned by ``cut_curve``).

    Raises:
        AssertionError: if the results contain non-finite or non-positive
            values (checked with ``assert``, hence skipped under ``python -O``).
    """
    intensity = apply_attenuation_correction(
        intensity,
        attenuation,
        scattering_angle,
    )

    intensity = apply_footprint_correction(
        intensity, scattering_angle, beam_width=beam_width, sample_length=sample_length, beam_shape=beam_shape
    )

    curve = intensity2reflectivity(intensity, normalize_mode, incoming_intensity)

    # The threshold acts on the normalized curve, not on the raw intensities.
    curve, scattering_angle = remove_low_statistics(curve, scattering_angle, thresh=min_intensity)

    q = angle_to_q(scattering_angle, wavelength)

    q, curve, q_ratio = cut_curve(q, curve, max_q, max_angle, wavelength)

    # Use the same floor as the filtering above instead of interp's fixed
    # default, so values kept in `curve` are not clipped in `curve_interp`
    # (assumes cut_curve only truncates the curve — TODO confirm).
    curve_interp = interp_reflectivity(q_interp, q, curve, min_value=min_intensity)

    assert np.all(np.isfinite(curve_interp))
    assert np.all(np.isfinite(curve))
    assert np.all(np.isfinite(q))
    assert np.all(np.isfinite(q_interp))
    assert np.all(curve > 0.)
    assert np.all(curve_interp > 0.)

    return {
        "curve_interp": curve_interp, "curve": curve, "q_values": q, "q_interp": q_interp, "q_ratio": q_ratio,
    }
78
+
79
+
80
def remove_low_statistics(curve, scattering_angle, thresh: float = 1e-7):
    """Discard points with too small or non-finite reflectivity.

    Args:
        curve: reflectivity values (an array supporting boolean-mask indexing).
        scattering_angle: scattering angles aligned element-wise with ``curve``.
        thresh: points with ``curve <= thresh`` are dropped. Defaults to 1e-7.

    Returns:
        tuple: ``(curve, scattering_angle)`` restricted to the kept points.
    """
    above_threshold = curve > thresh
    finite = np.isfinite(curve)
    keep = above_threshold & finite
    kept_curve = curve[keep]
    kept_angles = scattering_angle[keep]
    return kept_curve, kept_angles
83
+
84
+
85
@dataclass
class StandardPreprocessing:
    """Holds default arguments for ``standard_preprocessing`` and applies them.

    The dataclass fields are the stored defaults. Keyword arguments given to
    :meth:`preprocess` (or to the instance call) override them for that call
    only, and may also include other ``standard_preprocessing`` keywords such
    as ``min_intensity``, ``max_q`` or ``max_angle``. :meth:`set_parameters`
    changes the stored defaults permanently.
    """

    q_interp: np.ndarray = None
    wavelength: float = 1.
    beam_width: float = None
    sample_length: float = None
    beam_shape: BEAM_SHAPE = "gauss"
    normalize_mode: NORMALIZE_MODE = "max"
    incoming_intensity: float = None

    def preprocess(self,
                   intensity: np.ndarray,
                   scattering_angle: np.ndarray,
                   attenuation: np.ndarray,
                   **kwargs
                   ) -> dict:
        """Run ``standard_preprocessing`` with stored defaults updated by ``kwargs``."""
        attrs = self._get_updated_attrs(**kwargs)
        return standard_preprocessing(
            intensity,
            scattering_angle,
            attenuation,
            **attrs
        )

    def __call__(self,
                 intensity: np.ndarray,
                 scattering_angle: np.ndarray,
                 attenuation: np.ndarray,
                 **kwargs
                 ) -> dict:
        # Delegate dynamically (instead of `__call__ = preprocess`) so that
        # subclasses overriding `preprocess` are honoured.
        return self.preprocess(intensity, scattering_angle, attenuation, **kwargs)

    def set_parameters(self, **kwargs) -> None:
        """Permanently update stored fields.

        Raises:
            KeyError: if a keyword is not a field; no field is changed then.
        """
        known = self._field_names()
        # Validate everything first so a bad key does not leave a partial update.
        for k in kwargs:
            if k not in known:
                raise KeyError(f'Unknown parameter {k}.')
        for k, v in kwargs.items():
            setattr(self, k, v)

    def _get_updated_attrs(self, **kwargs):
        """Return the stored fields as a dict, overridden by ``kwargs``."""
        current_attrs = {k: getattr(self, k) for k in self._field_names()}
        current_attrs.update(kwargs)
        return current_attrs

    def _field_names(self) -> tuple:
        """Names of all dataclass fields, including inherited ones."""
        from dataclasses import fields
        return tuple(f.name for f in fields(self))
@@ -0,0 +1,82 @@
1
+ import os
2
+ import tempfile
3
+ import yaml
4
+ from huggingface_hub import hf_hub_download, list_repo_files
5
+
6
class HuggingfaceQueryMatcher:
    """Downloads the available configurations files to a temporary directory and provides functionality for filtering those configuration files matching user specified queries.

    Args:
        repo_id (str): The Hugging Face repository ID.
        config_dir (str): Directory within the repo where YAML files are stored.
    """
    def __init__(self, repo_id='valentinsingularity/reflectivity', config_dir='configs'):
        self.repo_id = repo_id
        self.config_dir = config_dir
        self.cache = {
            'parsed_configs': None,
            'temp_dir': None
        }
        self._renew_cache()

    def _renew_cache(self):
        """Download every ``.yaml`` file under ``config_dir`` into a fresh temporary
        directory and store the parsed contents, keyed by file name, in ``self.cache``.

        NOTE(review): the temporary directory is not removed here; its path is kept
        in ``self.cache['temp_dir']``.
        """
        temp_dir = tempfile.mkdtemp()
        print(f"Temporary directory created at: {temp_dir}")

        repo_files = list_repo_files(self.repo_id, repo_type='model')
        config_files = [repo_file for repo_file in repo_files if self._is_config_file(repo_file)]

        parsed_configs = {}
        for repo_file in config_files:
            file_path = hf_hub_download(repo_id=self.repo_id, filename=repo_file, local_dir=temp_dir, repo_type='model')
            # safe_load: configs come from a remote repository, never use yaml.load here.
            with open(file_path, 'r', encoding='utf-8') as stream:
                config_data = yaml.safe_load(stream)
            parsed_configs[os.path.basename(file_path)] = config_data

        self.cache['parsed_configs'] = parsed_configs
        self.cache['temp_dir'] = temp_dir

    def _is_config_file(self, repo_path):
        """Return True if ``repo_path`` is a ``.yaml`` file inside ``config_dir``.

        The directory check is done on a path boundary so that e.g. ``configs_old/x.yaml``
        does not match ``config_dir='configs'``.
        """
        prefix = self.config_dir.strip('/')
        if prefix:
            prefix += '/'
        return repo_path.startswith(prefix) and repo_path.endswith('.yaml')

    def get_matching_configs(self, query):
        """retrieves configuration files that match the user specified query.

        Args:
            query (dict): Dictionary of key-value pairs to filter configurations, e.g. ``query = {'dset.prior_sampler.kwargs.max_num_layers': 3, 'dset.prior_sampler.kwargs.param_ranges.slds': [0., 100.]}``.
                For keys containing the ``param_ranges`` subkey a configuration is selected if the value of the query (i.e. desired parameter range)
                is a subrange of the parameter range in the configuration, in all other cases the values must match exactly.

        Returns:
            list: List of file names that match the query.
        """
        return [
            file_name
            for file_name, config_data in self.cache['parsed_configs'].items()
            if self.matches_query(config_data, query)
        ]

    def matches_query(self, config_data, query):
        """Return True if ``config_data`` satisfies every entry of ``query``.

        A config that lacks a queried ``param_ranges`` entry (or where it is not a
        ``[min, max]`` pair) does not match.
        """
        for q_key, q_value in query.items():
            keys = q_key.split('.')
            value = self.deep_get(config_data, keys)
            if 'param_ranges' in keys:
                if not self._is_range(value):
                    return False
                # The requested range must lie inside the configured one.
                if q_value[0] < value[0] or q_value[1] > value[1]:
                    return False
            elif value != q_value:
                return False

        return True

    @staticmethod
    def _is_range(value):
        """True if ``value`` looks like a ``[min, max]`` pair."""
        return isinstance(value, (list, tuple)) and len(value) >= 2

    def deep_get(self, d, keys):
        """Follow ``keys`` through nested dicts; return None if the path cannot be followed."""
        for key in keys:
            if not isinstance(d, dict):
                return None
            d = d.get(key, None)

        return d
@@ -0,0 +1,43 @@
1
+ from time import perf_counter
2
+ from contextlib import contextmanager
3
+ from functools import wraps
4
+
5
+
6
class EvaluateTime(list):
    """Timer that stores the duration of every timed block in itself.

    Each ``with timer(name):`` block appends its elapsed ``perf_counter``
    seconds to the list and then calls :meth:`action`, which prints it by
    default. ``repr`` reports the total time and the number of records.
    """

    @contextmanager
    def __call__(self, name: str, *args, **kwargs):
        start = perf_counter()
        yield
        delta_time = perf_counter() - start
        # Record the measurement so that __repr__ (sum/len) reflects it.
        self.append(delta_time)
        self.action(delta_time, name, *args, **kwargs)

    @staticmethod
    def action(delta_time, name, *args, **kwargs):
        """Hook called after each measurement; prints the elapsed time."""
        print(f"Time for {name} = {delta_time:.2f} sec")

    def __repr__(self):
        return f'EvaluateTime(total={sum(self)}, num_records={len(self)})'
19
+
20
+
21
def print_time(name: str or callable):
    """Time a block of code or every call of a function and print the duration.

    Given a string, returns a context manager labelled with that string.
    Given anything else (a function), returns a wrapped function labelled
    with the function's ``__name__``.
    """
    # NOTE: the annotation `str or callable` evaluates to just `str` at runtime;
    # callables are accepted as well.
    if not isinstance(name, str):
        return _print_time_wrap(name)
    return _print_time_context(name)
26
+
27
+
28
def _print_time_wrap(func, name: str = None):
    """Wrap ``func`` so each call prints its duration.

    The label is ``name`` if truthy, otherwise ``func.__name__``.
    """
    label = name if name else func.__name__

    @wraps(func)
    def _timed(*args, **kwargs):
        with _print_time_context(label):
            result = func(*args, **kwargs)
        return result

    return _timed
37
+
38
+
39
+ @contextmanager
40
+ def _print_time_context(name: str):
41
+ start = perf_counter()
42
+ yield
43
+ print(f"Time for {name} = {(perf_counter() - start):.2f} sec")
@@ -0,0 +1,56 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ from reflectorch.data_generation.priors.utils import uniform_sampler
5
+ from reflectorch.data_generation.priors.subprior_sampler import UniformSubPriorParams
6
+ from reflectorch.data_generation.priors.params import Params
7
+ from reflectorch.data_generation.likelihoods import LogLikelihood
8
+
9
+
10
def simple_sampler_solution(
    likelihood: LogLikelihood,
    predicted_params: UniformSubPriorParams,
    total_min_bounds: Tensor,
    total_max_bounds: Tensor,
    num: int = 2 ** 15,
    coef: float = 0.1,
) -> UniformSubPriorParams:
    """Refine a prediction by random sampling around it.

    Draws ``num`` candidates with ``sample_around_params`` and returns the one
    selected by ``get_best_mse_param`` (highest log-likelihood), re-attaching
    the predicted subprior bounds ``predicted_params.min_bounds`` /
    ``predicted_params.max_bounds``.
    """
    candidate_tensor = sample_around_params(
        predicted_params,
        total_min_bounds,
        total_max_bounds,
        num=num,
        coef=coef,
    )
    candidates = Params.from_tensor(candidate_tensor)
    return get_best_mse_param(
        candidates,
        likelihood,
        min_bounds=predicted_params.min_bounds,
        max_bounds=predicted_params.max_bounds,
    )
21
+
22
+
23
def sample_around_params(predicted_params: UniformSubPriorParams,
                         total_min_bounds: Tensor,
                         total_max_bounds: Tensor,
                         num: int = 2 ** 15,
                         coef: float = 0.1,
                         ) -> Tensor:
    """Uniformly sample ``num`` parameter vectors in a box around the prediction.

    The box half-width is ``coef`` times the predicted subprior width
    (``max_bounds - min_bounds``), and the box is clamped to
    ``[total_min_bounds, total_max_bounds]``. Row 0 of the result is then
    overwritten with row 0 of the predicted parameters, so the prediction
    itself is among the candidates (presumably a batch of one prediction —
    TODO confirm).
    """
    center = predicted_params.as_tensor(add_bounds=False)
    half_width = (predicted_params.max_bounds - predicted_params.min_bounds) * coef

    lower = torch.clamp(center - half_width, total_min_bounds, total_max_bounds)
    upper = torch.clamp(center + half_width, total_min_bounds, total_max_bounds)

    n_params = center.shape[-1]
    samples = uniform_sampler(lower, upper, num, n_params)
    # Keep the unperturbed prediction as the first candidate.
    samples[0] = center[0]
    return samples
39
+
40
+
41
def get_best_mse_param(
    params: Params,
    likelihood: LogLikelihood,
    min_bounds: Tensor = None,
    max_bounds: Tensor = None,
):
    """Return the candidate with the highest log-likelihood.

    Despite the name, the criterion is ``likelihood.calc_log_likelihood`` on the
    simulated curves ``params.reflectivity(likelihood.q)``. When ``min_bounds``
    is given, the winner is returned as ``UniformSubPriorParams`` with
    ``min_bounds`` and ``max_bounds`` appended (both are then required);
    otherwise the one-element slice of ``params`` is returned as is.
    """
    q_values = likelihood.q
    curves = params.reflectivity(q_values)
    scores = likelihood.calc_log_likelihood(curves)
    winner = torch.argmax(scores)
    best = params[winner:winner + 1]

    if min_bounds is None:
        return best

    stacked = torch.cat(
        [best.as_tensor(), torch.atleast_2d(min_bounds), torch.atleast_2d(max_bounds)],
        -1,
    )
    return UniformSubPriorParams.from_tensor(stacked)