reflectorch 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def interp_reflectivity(q_interp, q, reflectivity, min_value: float = 1e-10, logspace = False):
    """Interpolate a reflectivity curve on a base 10 logarithmic intensity scale

    The reflectivity is clipped from below at ``min_value``, its base 10 logarithm is linearly
    interpolated with ``np.interp`` and the result is transformed back with ``10 ** x``.

    Args:
        q_interp (array-like): reciprocal space points used for the interpolation
        q (array-like): reciprocal space points of the measured reflectivity curve
        reflectivity (array-like): reflectivity curve measured at the points ``q``
        min_value (float, optional): minimum intensity of the reflectivity curve. Defaults to 1e-10.
        logspace (bool, optional): if True, the interpolation is additionally performed on a base 10
            logarithmic ``q`` axis (``np.log10`` is applied to ``q`` and ``q_interp``). Defaults to False.

    Returns:
        array-like: interpolated reflectivity curve
    """
    q = np.asarray(q)
    log_reflectivity = np.log10(np.clip(reflectivity, min_value, None))

    # np.interp does not check that the sample points are increasing and returns nonsense otherwise,
    # so a 1D curve recorded in another order is brought into ascending q order first.
    if q.ndim == 1 and q.size > 1 and np.any(np.diff(q) < 0):
        order = np.argsort(q, kind="stable")
        q = q[order]
        log_reflectivity = log_reflectivity[order]

    if logspace:
        return 10 ** np.interp(np.log10(q_interp), np.log10(q), log_reflectivity)
    return 10 ** np.interp(q_interp, q, log_reflectivity)
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
try:
|
|
2
|
+
from typing import Literal
|
|
3
|
+
except ImportError:
|
|
4
|
+
from typing_extensions import Literal
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from numpy import ndarray
|
|
8
|
+
|
|
9
|
+
NORMALIZE_MODE = Literal["first", "max", "incoming_intensity"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def intensity2reflectivity(intensity: ndarray, mode: NORMALIZE_MODE, incoming_intensity=None) -> np.ndarray:
    """Normalize measured intensities by a factor selected through ``mode``

    Args:
        intensity (ndarray): array of measured intensities
        mode (NORMALIZE_MODE): ``"first"`` divides by the first element of ``intensity``, ``"max"`` divides by
            its maximum and ``"incoming_intensity"`` divides by the ``incoming_intensity`` argument
        incoming_intensity (optional): normalization factor for the ``"incoming_intensity"`` mode. Defaults to None.

    Raises:
        ValueError: if ``mode`` is ``"incoming_intensity"`` but ``incoming_intensity`` is None,
            or if ``mode`` is not one of the supported values

    Returns:
        np.ndarray: the normalized intensities
    """
    if mode == "first":
        normalization = intensity[0]
    elif mode == "max":
        normalization = intensity.max()
    elif mode == "incoming_intensity":
        if incoming_intensity is None:
            raise ValueError("incoming_intensity is None")
        normalization = incoming_intensity
    else:
        raise ValueError(f"Unknown mode {mode}")

    return intensity / normalization
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from reflectorch.inference.preprocess_exp.interpolation import interp_reflectivity
|
|
6
|
+
from reflectorch.inference.preprocess_exp.footprint import apply_footprint_correction, BEAM_SHAPE
|
|
7
|
+
from reflectorch.inference.preprocess_exp.normalize import intensity2reflectivity, NORMALIZE_MODE
|
|
8
|
+
from reflectorch.inference.preprocess_exp.attenuation import apply_attenuation_correction
|
|
9
|
+
from reflectorch.inference.preprocess_exp.cut_with_q_ratio import cut_curve
|
|
10
|
+
from reflectorch.utils import angle_to_q
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def standard_preprocessing(
        intensity: np.ndarray,
        scattering_angle: np.ndarray,
        attenuation: np.ndarray,
        q_interp: np.ndarray,
        wavelength: float,
        beam_width: float,
        sample_length: float,
        min_intensity: float = 1e-10,
        beam_shape: BEAM_SHAPE = "gauss",
        normalize_mode: NORMALIZE_MODE = "max",
        incoming_intensity: float = None,
        max_q: float = None,  # if provided, max_angle is ignored
        max_angle: float = None,
) -> dict:
    """Preprocesses a raw experimental reflectivity curve

    The steps are applied in this order: attenuation correction, footprint correction, normalization,
    removal of low or non-finite points, conversion of the angles to q, cutting of the curve and
    interpolation onto ``q_interp``.

    Args:
        intensity (np.ndarray): array of intensities of the reflectivity curve
        scattering_angle (np.ndarray): array of scattering angles
        attenuation (np.ndarray): attenuation factors for each measured point
        q_interp (np.ndarray): reciprocal space points used for the interpolation
        wavelength (float): the wavelength of the beam
        beam_width (float): the beam width
        sample_length (float): the sample length
        min_intensity (float, optional): threshold passed to ``remove_low_statistics``; normalized points
            that are not above it are removed. Defaults to 1e-10.
        beam_shape (BEAM_SHAPE, optional): the shape of the beam, either "gauss" or "box". Defaults to "gauss".
        normalize_mode (NORMALIZE_MODE, optional): normalization mode, either "first", "max" or
            "incoming_intensity". Defaults to "max".
        incoming_intensity (float, optional): normalization factor for the "incoming_intensity"
            normalization. Defaults to None.
        max_q (float, optional): the maximum q value at which the curve is cut. Defaults to None.
        max_angle (float, optional): the maximum scattering angle at which the curve is cut; only used
            if max_q is not provided. Defaults to None.

    Raises:
        AssertionError: if any of the resulting arrays contains non-finite values or if the curve
            (before or after interpolation) is not strictly positive

    Returns:
        dict: dictionary with the keys ``"curve_interp"`` (interpolated curve), ``"curve"`` (curve before
            interpolation), ``"q_values"`` (q values before interpolation), ``"q_interp"`` (q values after
            interpolation) and ``"q_ratio"`` (third value returned by ``cut_curve``)
    """
    # Corrections are applied to the raw intensities before any normalization.
    corrected = apply_attenuation_correction(intensity, attenuation, scattering_angle)
    corrected = apply_footprint_correction(
        corrected,
        scattering_angle,
        beam_width=beam_width,
        sample_length=sample_length,
        beam_shape=beam_shape,
    )

    reflectivity = intensity2reflectivity(corrected, normalize_mode, incoming_intensity)
    reflectivity, kept_angles = remove_low_statistics(reflectivity, scattering_angle, thresh=min_intensity)

    q_values = angle_to_q(kept_angles, wavelength)
    q_values, reflectivity, q_ratio = cut_curve(q_values, reflectivity, max_q, max_angle, wavelength)

    reflectivity_interp = interp_reflectivity(q_interp, q_values, reflectivity)

    # Sanity checks on the outputs (same order as the individual checks they replace).
    for array in (reflectivity_interp, reflectivity, q_values, q_interp):
        assert np.all(np.isfinite(array))
    for array in (reflectivity, reflectivity_interp):
        assert np.all(array > 0.)

    return dict(
        curve_interp=reflectivity_interp,
        curve=reflectivity,
        q_values=q_values,
        q_interp=q_interp,
        q_ratio=q_ratio,
    )
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def remove_low_statistics(curve, scattering_angle, thresh: float = 1e-7):
    """Keep only the points of a curve that are finite and strictly above a threshold

    Args:
        curve (np.ndarray): values of the curve
        scattering_angle (np.ndarray): scattering angles, indexed with the same mask as ``curve``
        thresh (float, optional): points with ``curve <= thresh`` are removed. Defaults to 1e-7.

    Returns:
        tuple: the filtered curve and the correspondingly filtered scattering angles
    """
    keep = np.isfinite(curve) & (curve > thresh)
    filtered_curve = curve[keep]
    filtered_angle = scattering_angle[keep]
    return filtered_curve, filtered_angle
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@dataclass
class StandardPreprocessing:
    """Stores default arguments for ``standard_preprocessing`` and applies it to measured data

    The annotated fields are forwarded as keyword arguments to ``standard_preprocessing``; they can be
    changed permanently with ``set_parameters`` or overridden for a single call through the keyword
    arguments of ``preprocess`` (also available by calling the instance).
    """
    q_interp: np.ndarray = None
    wavelength: float = 1.
    beam_width: float = None
    sample_length: float = None
    beam_shape: BEAM_SHAPE = "gauss"
    normalize_mode: NORMALIZE_MODE = "max"
    incoming_intensity: float = None

    def preprocess(self,
                   intensity: np.ndarray,
                   scattering_angle: np.ndarray,
                   attenuation: np.ndarray,
                   **kwargs
                   ) -> dict:
        """Run ``standard_preprocessing`` with the stored settings

        Args:
            intensity (np.ndarray): array of intensities of the reflectivity curve
            scattering_angle (np.ndarray): array of scattering angles
            attenuation (np.ndarray): attenuation factors for each measured point
            **kwargs: keyword arguments of ``standard_preprocessing`` which take precedence over
                the stored settings for this call only

        Returns:
            dict: the dictionary returned by ``standard_preprocessing``
        """
        call_kwargs = self._get_updated_attrs(**kwargs)
        return standard_preprocessing(intensity, scattering_angle, attenuation, **call_kwargs)

    __call__ = preprocess

    def set_parameters(self, **kwargs) -> None:
        """Update stored settings in the order given

        Raises:
            KeyError: at the first name that is not an annotated field; settings listed
                before it have already been updated at that point
        """
        for name, value in kwargs.items():
            if name not in self.__annotations__:
                raise KeyError(f'Unknown parameter {name}.')
            setattr(self, name, value)

    def _get_updated_attrs(self, **kwargs):
        """Return the stored settings as a new dict, overridden and extended by ``kwargs``"""
        stored = {name: getattr(self, name) for name in self.__annotations__}
        return {**stored, **kwargs}
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import tempfile
|
|
3
|
+
import yaml
|
|
4
|
+
from huggingface_hub import hf_hub_download, list_repo_files
|
|
5
|
+
|
|
6
|
+
class HuggingfaceQueryMatcher:
    """Downloads the available configurations files to a temporary directory and provides functionality for filtering those configuration files matching user specified queries.

    Args:
        repo_id (str): The Hugging Face repository ID.
        config_dir (str): Directory within the repo where YAML files are stored.
    """
    def __init__(self, repo_id='valentinsingularity/reflectivity', config_dir='configs'):
        self.repo_id = repo_id
        self.config_dir = config_dir
        self.cache = {
            'parsed_configs': None,
            'temp_dir': None
        }
        self._renew_cache()

    def _renew_cache(self):
        """Downloads all YAML files under ``config_dir`` into a new temporary directory and parses them.

        The parsed configurations are stored in ``self.cache['parsed_configs']`` keyed by the base
        name of the file, and the temporary directory in ``self.cache['temp_dir']``.
        """
        temp_dir = tempfile.mkdtemp()
        print(f"Temporary directory created at: {temp_dir}")

        repo_files = list_repo_files(self.repo_id, repo_type='model')
        config_files = [
            name for name in repo_files
            if name.startswith(self.config_dir) and name.endswith('.yaml')
        ]

        downloaded_files = [
            hf_hub_download(repo_id=self.repo_id, filename=name, local_dir=temp_dir, repo_type='model')
            for name in config_files
        ]

        # NOTE: configurations are keyed by base name, so files with the same name
        # in different sub-directories overwrite each other.
        parsed_configs = {}
        for file_path in downloaded_files:
            with open(file_path, 'r') as stream:
                parsed_configs[os.path.basename(file_path)] = yaml.safe_load(stream)

        self.cache['parsed_configs'] = parsed_configs
        self.cache['temp_dir'] = temp_dir

    def get_matching_configs(self, query):
        """retrieves configuration files that match the user specified query.

        Args:
            query (dict): Dictionary of key-value pairs to filter configurations, e.g. ``query = {'dset.prior_sampler.kwargs.max_num_layers': 3, 'dset.prior_sampler.kwargs.param_ranges.slds': [0., 100.]}``.
                        For keys containing the ``param_ranges`` subkey a configuration is selected if the value of the query (i.e. desired parameter range)
                        is a subrange of the parameter range in the configuration, in all other cases the values must match exactly.

        Returns:
            list: List of file names that match the query.
        """
        return [
            file_name
            for file_name, config_data in self.cache['parsed_configs'].items()
            if self.matches_query(config_data, query)
        ]

    def matches_query(self, config_data, query):
        """Checks whether a single parsed configuration satisfies every entry of the query.

        Args:
            config_data (dict): parsed content of a configuration file.
            query (dict): maps dot-separated key paths to the requested values (see ``get_matching_configs``).

        Returns:
            bool: True if all entries of the query are satisfied, False otherwise. A ``param_ranges``
                entry that is absent from the configuration counts as not satisfied.
        """
        for q_key, q_value in query.items():
            keys = q_key.split('.')
            value = self.deep_get(config_data, keys)
            if 'param_ranges' in keys:
                # A range that the configuration does not define cannot contain the requested
                # range; treat it as a non-match instead of failing on ``None[0]``.
                if value is None:
                    return False
                if q_value[0] < value[0] or q_value[1] > value[1]:
                    return False
            elif value != q_value:
                return False

        return True

    def deep_get(self, d, keys):
        """Looks up a nested value by following ``keys`` through nested dictionaries.

        Args:
            d (dict): the (nested) dictionary to search.
            keys (list): sequence of keys, one per nesting level.

        Returns:
            The value found at the end of the key path, or None if the path cannot be followed
            completely (a key is missing or an intermediate value is not a dictionary).
        """
        for key in keys:
            # Stop as soon as the path runs into a non-dict: returning that intermediate value
            # would make a deeper query key match a shallower configuration entry.
            if not isinstance(d, dict):
                return None
            d = d.get(key, None)

        return d
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
from time import perf_counter
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from functools import wraps
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class EvaluateTime(list):
    """List of measured durations (in seconds) that is filled by using the instance as a context manager.

    Calling the instance returns a context manager which measures the time spent in the ``with`` body
    with ``perf_counter``, appends it to the list and passes it to ``action``. If the body raises,
    the code after the ``yield`` is not reached, so nothing is recorded or reported for that run.
    """

    @contextmanager
    def __call__(self, name: str, *args, **kwargs):
        """Time the enclosed ``with`` body.

        Args:
            name (str): label of the timed section, forwarded to ``action``
            *args: additional positional arguments forwarded to ``action``
            **kwargs: additional keyword arguments forwarded to ``action``
        """
        start = perf_counter()
        yield
        delta_time = perf_counter() - start
        # Store the measurement: without this the list stays empty and __repr__
        # always reports ``total=0, num_records=0``.
        self.append(delta_time)
        self.action(delta_time, name, *args, **kwargs)

    @staticmethod
    def action(delta_time, name, *args, **kwargs):
        """Report a finished measurement; the default implementation prints it."""
        print(f"Time for {name} = {delta_time:.2f} sec")

    def __repr__(self):
        return f'EvaluateTime(total={sum(self)}, num_records={len(self)})'
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def print_time(name: str or callable):
    """Print the elapsed time of a code section, usable as a context manager or as a decorator.

    Args:
        name: if a string, it is used as the label and a context manager is returned
            (``with print_time("label"): ...``); otherwise the argument is treated as the
            function to wrap (``@print_time``) and the wrapped function is returned.

    Returns:
        a context manager when ``name`` is a string, otherwise the wrapped function
    """
    # NOTE(review): the annotation ``str or callable`` evaluates to plain ``str``;
    # it is kept unchanged here to leave the signature untouched.
    if not isinstance(name, str):
        return _print_time_wrap(name)
    return _print_time_context(name)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _print_time_wrap(func, name: str = None):
    """Wrap ``func`` so that every call prints its elapsed time.

    Args:
        func (callable): the function to wrap
        name (str, optional): label used in the printed message; falls back to
            ``func.__name__`` when not provided (or empty). Defaults to None.

    Returns:
        callable: the wrapped function, which returns whatever ``func`` returns
    """
    # The label is resolved once, at decoration time.
    label = name if name else func.__name__

    @wraps(func)
    def wrapped_func(*args, **kwargs):
        with _print_time_context(label):
            result = func(*args, **kwargs)
        return result

    return wrapped_func
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@contextmanager
def _print_time_context(name: str):
    """Context manager that prints the time spent in its ``with`` body.

    The message is printed only when the body finishes without raising.

    Args:
        name (str): label used in the printed message
    """
    started_at = perf_counter()
    yield
    elapsed = perf_counter() - started_at
    print(f"Time for {name} = {elapsed:.2f} sec")
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor
|
|
3
|
+
|
|
4
|
+
from reflectorch.data_generation.priors.utils import uniform_sampler
|
|
5
|
+
from reflectorch.data_generation.priors.subprior_sampler import UniformSubPriorParams
|
|
6
|
+
from reflectorch.data_generation.priors.params import Params
|
|
7
|
+
from reflectorch.data_generation.likelihoods import LogLikelihood
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def simple_sampler_solution(
        likelihood: LogLikelihood,
        predicted_params: UniformSubPriorParams,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        num: int = 2 ** 15,
        coef: float = 0.1,
) -> UniformSubPriorParams:
    """Refine predicted parameters by sampling around them and keeping the most likely candidate.

    Candidates are drawn with ``sample_around_params``, converted with ``Params.from_tensor`` and
    ranked with ``get_best_mse_param``.

    Args:
        likelihood (LogLikelihood): likelihood used to rank the sampled candidates
        predicted_params (UniformSubPriorParams): predicted parameters; their ``min_bounds`` and
            ``max_bounds`` are attached to the returned solution
        total_min_bounds (Tensor): lower limits the sampling interval is clamped to
        total_max_bounds (Tensor): upper limits the sampling interval is clamped to
        num (int, optional): number of sampled candidates. Defaults to 2 ** 15.
        coef (float, optional): fraction of the predicted bound width sampled on each side of the
            predicted values. Defaults to 0.1.

    Returns:
        UniformSubPriorParams: the candidate with the highest log-likelihood
    """
    candidates_t = sample_around_params(
        predicted_params,
        total_min_bounds,
        total_max_bounds,
        num=num,
        coef=coef,
    )
    candidates = Params.from_tensor(candidates_t)

    return get_best_mse_param(
        candidates,
        likelihood,
        predicted_params.min_bounds,
        predicted_params.max_bounds,
    )
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def sample_around_params(predicted_params: UniformSubPriorParams,
                         total_min_bounds: Tensor,
                         total_max_bounds: Tensor,
                         num: int = 2 ** 15,
                         coef: float = 0.1,
                         ) -> Tensor:
    """Sample parameter candidates from an interval centered on the predicted parameters.

    The half-width of the interval is ``coef`` times the width of the predicted bounds; both ends
    of the interval are clamped to ``[total_min_bounds, total_max_bounds]``.

    Args:
        predicted_params (UniformSubPriorParams): predicted parameters with their bounds
        total_min_bounds (Tensor): lower limits the sampling interval is clamped to
        total_max_bounds (Tensor): upper limits the sampling interval is clamped to
        num (int, optional): number of samples requested from ``uniform_sampler``. Defaults to 2 ** 15.
        coef (float, optional): fraction of the predicted bound width used as half-width. Defaults to 0.1.

    Returns:
        Tensor: the sampled parameters; the first entry is replaced by the first entry of the
            predicted parameters, so the prediction itself is always among the candidates
    """
    center = predicted_params.as_tensor(add_bounds=False)
    half_width = (predicted_params.max_bounds - predicted_params.min_bounds) * coef

    lower = (center - half_width).clamp(min=total_min_bounds, max=total_max_bounds)
    upper = (center + half_width).clamp(min=total_min_bounds, max=total_max_bounds)

    samples = uniform_sampler(lower, upper, num, center.shape[-1])
    samples[0] = center[0]
    return samples
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_best_mse_param(
        params: Params,
        likelihood: LogLikelihood,
        min_bounds: Tensor = None,
        max_bounds: Tensor = None,
):
    """Select the parameter set whose simulated curve has the highest log-likelihood.

    Args:
        params (Params): candidate parameters; their reflectivity is computed at ``likelihood.q``
        likelihood (LogLikelihood): provides ``q`` and ``calc_log_likelihood``
        min_bounds (Tensor, optional): if provided, the result is converted to ``UniformSubPriorParams``
            with these minimum bounds appended. Defaults to None.
        max_bounds (Tensor, optional): maximum bounds appended together with ``min_bounds``; only used
            when ``min_bounds`` is provided. Defaults to None.

    Returns:
        the best candidate as a length-one slice of ``params``, or a ``UniformSubPriorParams``
            built from it and the bounds when ``min_bounds`` is provided
    """
    log_probs = likelihood.calc_log_likelihood(params.reflectivity(likelihood.q))
    best_idx = torch.argmax(log_probs)
    # Slice instead of index so that the leading dimension is kept.
    best_param = params[best_idx:best_idx + 1]

    if min_bounds is None:
        return best_param

    with_bounds = torch.cat(
        [best_param.as_tensor(), torch.atleast_2d(min_bounds), torch.atleast_2d(max_bounds)],
        -1,
    )
    return UniformSubPriorParams.from_tensor(with_bounds)
|