reflectorch 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic; see the registry's notice for this release for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -128
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -280
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -223
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -1374
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +36 -36
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +523 -516
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -19
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -262
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -200
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -15
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -19
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +389 -389
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -434
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -404
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +97 -97
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -126
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.4.0.dist-info/RECORD +0 -88
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +0 -0
- {reflectorch-1.4.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,82 +1,82 @@
|
|
|
1
|
-
import os
|
|
2
|
-
import tempfile
|
|
3
|
-
import yaml
|
|
4
|
-
from huggingface_hub import hf_hub_download, list_repo_files
|
|
5
|
-
|
|
6
|
-
class HuggingfaceQueryMatcher:
    """Downloads the available configuration files to a temporary directory and provides functionality for filtering those configuration files matching user specified queries.

    Args:
        repo_id (str): The Hugging Face repository ID.
        config_dir (str): Directory within the repo where YAML files are stored.
    """
    def __init__(self, repo_id='valentinsingularity/reflectivity', config_dir='configs'):
        self.repo_id = repo_id
        self.config_dir = config_dir
        # 'parsed_configs' maps a config file name to its parsed YAML content;
        # 'temp_dir' is the local directory the files were downloaded into.
        self.cache = {
            'parsed_configs': None,
            'temp_dir': None
        }
        self._renew_cache()

    def _renew_cache(self):
        """Download every ``.yaml`` file under ``config_dir`` and store the parsed contents in ``self.cache``.

        The temporary directory is not removed here; its path is kept in
        ``self.cache['temp_dir']`` so the downloaded files stay available.
        """
        temp_dir = tempfile.mkdtemp()
        print(f"Temporary directory created at: {temp_dir}")

        # Match on the directory component so that e.g. 'configs_old/x.yaml' is not
        # selected for config_dir='configs'. An empty config_dir still selects every
        # YAML file in the repository, as before.
        prefix = self.config_dir.rstrip('/')
        if prefix:
            prefix += '/'

        repo_files = list_repo_files(self.repo_id, repo_type='model')
        config_files = [file for file in repo_files if file.startswith(prefix) and file.endswith('.yaml')]

        parsed_configs = {}
        for file in config_files:
            file_path = hf_hub_download(repo_id=self.repo_id, filename=file, local_dir=temp_dir, repo_type='model')
            with open(file_path, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
            # Keyed by bare file name, which is what get_matching_configs returns.
            parsed_configs[os.path.basename(file_path)] = config_data

        self.cache['parsed_configs'] = parsed_configs
        self.cache['temp_dir'] = temp_dir

    def get_matching_configs(self, query):
        """Return the configuration files that match the user specified query.

        Args:
            query (dict): Dictionary of key-value pairs to filter configurations, e.g. ``query = {'dset.prior_sampler.kwargs.max_num_layers': 3, 'dset.prior_sampler.kwargs.param_ranges.slds': [0., 100.]}``.
                For keys containing the ``param_ranges`` subkey a configuration is selected if the value of the query (i.e. desired parameter range)
                is a subrange of the parameter range in the configuration, in all other cases the values must match exactly.

        Returns:
            list: List of file names that match the query.
        """
        return [
            file_name
            for file_name, config_data in self.cache['parsed_configs'].items()
            if self.matches_query(config_data, query)
        ]

    def matches_query(self, config_data, query):
        """Return True if ``config_data`` satisfies every condition in ``query``.

        Keys are dot-separated paths into the nested configuration. For paths
        containing ``param_ranges`` the queried ``[low, high]`` must lie inside the
        configured range; a configuration lacking that range does not match
        (previously this raised a TypeError). All other keys require equality.
        """
        for q_key, q_value in query.items():
            keys = q_key.split('.')
            value = self.deep_get(config_data, keys)
            if 'param_ranges' in keys:
                if value is None or q_value[0] < value[0] or q_value[1] > value[1]:
                    return False
            elif value != q_value:
                return False
        return True

    def deep_get(self, d, keys):
        """Follow ``keys`` through nested dicts and return the value reached.

        Returns None as soon as a key is missing or an intermediate value is not a
        dict, rather than silently ignoring the remaining keys.
        """
        for key in keys:
            if not isinstance(d, dict):
                return None
            d = d.get(key, None)
        return d
|
1
|
+
import os
|
|
2
|
+
import tempfile
|
|
3
|
+
import yaml
|
|
4
|
+
from huggingface_hub import hf_hub_download, list_repo_files
|
|
5
|
+
|
|
6
|
+
class HuggingfaceQueryMatcher:
    """Downloads the available configuration files to a temporary directory and provides functionality for filtering those configuration files matching user specified queries.

    Args:
        repo_id (str): The Hugging Face repository ID.
        config_dir (str): Directory within the repo where YAML files are stored.
    """
    def __init__(self, repo_id='valentinsingularity/reflectivity', config_dir='configs'):
        self.repo_id = repo_id
        self.config_dir = config_dir
        # 'parsed_configs' maps a config file name to its parsed YAML content;
        # 'temp_dir' is the local directory the files were downloaded into.
        self.cache = {
            'parsed_configs': None,
            'temp_dir': None
        }
        self._renew_cache()

    def _renew_cache(self):
        """Download every ``.yaml`` file under ``config_dir`` and store the parsed contents in ``self.cache``.

        The temporary directory is not removed here; its path is kept in
        ``self.cache['temp_dir']`` so the downloaded files stay available.
        """
        temp_dir = tempfile.mkdtemp()
        print(f"Temporary directory created at: {temp_dir}")

        # Match on the directory component so that e.g. 'configs_old/x.yaml' is not
        # selected for config_dir='configs'. An empty config_dir still selects every
        # YAML file in the repository, as before.
        prefix = self.config_dir.rstrip('/')
        if prefix:
            prefix += '/'

        repo_files = list_repo_files(self.repo_id, repo_type='model')
        config_files = [file for file in repo_files if file.startswith(prefix) and file.endswith('.yaml')]

        parsed_configs = {}
        for file in config_files:
            file_path = hf_hub_download(repo_id=self.repo_id, filename=file, local_dir=temp_dir, repo_type='model')
            with open(file_path, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
            # Keyed by bare file name, which is what get_matching_configs returns.
            parsed_configs[os.path.basename(file_path)] = config_data

        self.cache['parsed_configs'] = parsed_configs
        self.cache['temp_dir'] = temp_dir

    def get_matching_configs(self, query):
        """Return the configuration files that match the user specified query.

        Args:
            query (dict): Dictionary of key-value pairs to filter configurations, e.g. ``query = {'dset.prior_sampler.kwargs.max_num_layers': 3, 'dset.prior_sampler.kwargs.param_ranges.slds': [0., 100.]}``.
                For keys containing the ``param_ranges`` subkey a configuration is selected if the value of the query (i.e. desired parameter range)
                is a subrange of the parameter range in the configuration, in all other cases the values must match exactly.

        Returns:
            list: List of file names that match the query.
        """
        return [
            file_name
            for file_name, config_data in self.cache['parsed_configs'].items()
            if self.matches_query(config_data, query)
        ]

    def matches_query(self, config_data, query):
        """Return True if ``config_data`` satisfies every condition in ``query``.

        Keys are dot-separated paths into the nested configuration. For paths
        containing ``param_ranges`` the queried ``[low, high]`` must lie inside the
        configured range; a configuration lacking that range does not match
        (previously this raised a TypeError). All other keys require equality.
        """
        for q_key, q_value in query.items():
            keys = q_key.split('.')
            value = self.deep_get(config_data, keys)
            if 'param_ranges' in keys:
                if value is None or q_value[0] < value[0] or q_value[1] > value[1]:
                    return False
            elif value != q_value:
                return False
        return True

    def deep_get(self, d, keys):
        """Follow ``keys`` through nested dicts and return the value reached.

        Returns None as soon as a key is missing or an intermediate value is not a
        dict, rather than silently ignoring the remaining keys.
        """
        for key in keys:
            if not isinstance(d, dict):
                return None
            d = d.get(key, None)
        return d
|
|
@@ -1,43 +1,43 @@
|
|
|
1
|
-
from time import perf_counter
|
|
2
|
-
from contextlib import contextmanager
|
|
3
|
-
from functools import wraps
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class EvaluateTime(list):
    """Reusable timer: ``with timer('label'): ...`` measures the enclosed block.

    When the ``with`` body finishes without raising, ``action`` is called with the
    elapsed seconds, the label and any extra arguments. The base ``action`` only
    prints the measurement; nothing is appended to this list, so the totals shown
    by ``repr`` stay at zero unless a subclass stores measurements itself.
    NOTE(review): confirm that not recording into the list is intended.
    """

    @contextmanager
    def __call__(self, name: str, *args, **kwargs):
        t0 = perf_counter()
        yield
        elapsed = perf_counter() - t0
        self.action(elapsed, name, *args, **kwargs)

    @staticmethod
    def action(delta_time, name, *args, **kwargs):
        """Report one measurement by printing it with two decimals."""
        message = f"Time for {name} = {delta_time:.2f} sec"
        print(message)

    def __repr__(self):
        total = sum(self)
        count = len(self)
        return f'EvaluateTime(total={total}, num_records={count})'
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def print_time(name: str or callable):
    """Time either a ``with`` block or every call of a function.

    ``print_time('label')`` returns a context manager that prints the run time of
    its body. Any non-string argument (intended: a function) is wrapped so that each
    call prints its run time; this also allows use as a bare ``@print_time`` decorator.
    """
    # Note: the annotation ``str or callable`` evaluates to just ``str`` at runtime.
    if not isinstance(name, str):
        return _print_time_wrap(name)
    return _print_time_context(name)
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def _print_time_wrap(func, name: str = None):
    """Wrap ``func`` so that each call prints how long it took.

    The printed label is ``name`` if it is truthy, otherwise ``func.__name__``.
    The wrapper keeps ``func``'s metadata via ``functools.wraps``.
    """
    label = name if name else func.__name__

    def wrapped_func(*args, **kwargs):
        with _print_time_context(label):
            return func(*args, **kwargs)

    return wraps(func)(wrapped_func)
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
@contextmanager
|
|
40
|
-
def _print_time_context(name: str):
|
|
41
|
-
start = perf_counter()
|
|
42
|
-
yield
|
|
43
|
-
print(f"Time for {name} = {(perf_counter() - start):.2f} sec")
|
|
1
|
+
from time import perf_counter
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from functools import wraps
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class EvaluateTime(list):
    """Reusable timer: ``with timer('label'): ...`` measures the enclosed block.

    When the ``with`` body finishes without raising, ``action`` is called with the
    elapsed seconds, the label and any extra arguments. The base ``action`` only
    prints the measurement; nothing is appended to this list, so the totals shown
    by ``repr`` stay at zero unless a subclass stores measurements itself.
    NOTE(review): confirm that not recording into the list is intended.
    """

    @contextmanager
    def __call__(self, name: str, *args, **kwargs):
        t0 = perf_counter()
        yield
        elapsed = perf_counter() - t0
        self.action(elapsed, name, *args, **kwargs)

    @staticmethod
    def action(delta_time, name, *args, **kwargs):
        """Report one measurement by printing it with two decimals."""
        message = f"Time for {name} = {delta_time:.2f} sec"
        print(message)

    def __repr__(self):
        total = sum(self)
        count = len(self)
        return f'EvaluateTime(total={total}, num_records={count})'
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def print_time(name: str or callable):
    """Time either a ``with`` block or every call of a function.

    ``print_time('label')`` returns a context manager that prints the run time of
    its body. Any non-string argument (intended: a function) is wrapped so that each
    call prints its run time; this also allows use as a bare ``@print_time`` decorator.
    """
    # Note: the annotation ``str or callable`` evaluates to just ``str`` at runtime.
    if not isinstance(name, str):
        return _print_time_wrap(name)
    return _print_time_context(name)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _print_time_wrap(func, name: str = None):
    """Wrap ``func`` so that each call prints how long it took.

    The printed label is ``name`` if it is truthy, otherwise ``func.__name__``.
    The wrapper keeps ``func``'s metadata via ``functools.wraps``.
    """
    label = name if name else func.__name__

    def wrapped_func(*args, **kwargs):
        with _print_time_context(label):
            return func(*args, **kwargs)

    return wraps(func)(wrapped_func)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@contextmanager
|
|
40
|
+
def _print_time_context(name: str):
|
|
41
|
+
start = perf_counter()
|
|
42
|
+
yield
|
|
43
|
+
print(f"Time for {name} = {(perf_counter() - start):.2f} sec")
|
|
@@ -1,56 +1,56 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from torch import Tensor
|
|
3
|
-
|
|
4
|
-
from reflectorch.data_generation.priors.utils import uniform_sampler
|
|
5
|
-
from reflectorch.data_generation.priors.subprior_sampler import UniformSubPriorParams
|
|
6
|
-
from reflectorch.data_generation.priors.params import Params
|
|
7
|
-
from reflectorch.data_generation.likelihoods import LogLikelihood
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
def simple_sampler_solution(
        likelihood: LogLikelihood,
        predicted_params: UniformSubPriorParams,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        num: int = 2 ** 15,
        coef: float = 0.1,
) -> UniformSubPriorParams:
    """Refine a predicted parameter set by random search around it.

    Draws ``num`` candidates near ``predicted_params`` with ``sample_around_params``,
    scores them with ``likelihood`` via ``get_best_mse_param`` and returns the best
    candidate together with the prediction's ``min_bounds`` / ``max_bounds``.
    """
    candidates_t = sample_around_params(
        predicted_params, total_min_bounds, total_max_bounds, num=num, coef=coef,
    )
    candidates = Params.from_tensor(candidates_t)
    lower, upper = predicted_params.min_bounds, predicted_params.max_bounds
    return get_best_mse_param(candidates, likelihood, lower, upper)
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def sample_around_params(predicted_params: UniformSubPriorParams,
                         total_min_bounds: Tensor,
                         total_max_bounds: Tensor,
                         num: int = 2 ** 15,
                         coef: float = 0.1,
                         ) -> Tensor:
    """Uniformly sample ``num`` parameter vectors in a box around the prediction.

    The box half-width is ``coef`` times the width of the prediction's own
    ``[min_bounds, max_bounds]`` interval, and both box edges are clipped to
    ``[total_min_bounds, total_max_bounds]``. Entry 0 of the sampled tensor is then
    overwritten with entry 0 of the prediction, so the unperturbed prediction is
    always among the candidates (assumes ``uniform_sampler`` returns the samples
    along the first dimension; confirm against its definition).
    """
    center = predicted_params.as_tensor(add_bounds=False)
    half_width = (predicted_params.max_bounds - predicted_params.min_bounds) * coef

    lower = torch.clamp(center - half_width, total_min_bounds, total_max_bounds)
    upper = torch.clamp(center + half_width, total_min_bounds, total_max_bounds)

    samples = uniform_sampler(lower, upper, num, center.shape[-1])
    samples[0] = center[0]
    return samples
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
def get_best_mse_param(
        params: Params,
        likelihood: LogLikelihood,
        min_bounds: Tensor = None,
        max_bounds: Tensor = None,
):
    """Select the candidate in ``params`` with the highest log-likelihood.

    Despite the name, candidates are ranked by ``likelihood.calc_log_likelihood``
    (argmax) of their reflectivity curves at ``likelihood.q``. The winner is returned
    as a one-element slice of ``params``. If ``min_bounds`` is given, it is instead
    rebuilt as ``UniformSubPriorParams`` with ``min_bounds`` and ``max_bounds``
    concatenated after the parameter values.
    """
    curves = params.reflectivity(likelihood.q)
    scores = likelihood.calc_log_likelihood(curves)
    top = torch.argmax(scores)
    winner = params[top:top + 1]

    if min_bounds is None:
        return winner

    stacked = torch.cat(
        [winner.as_tensor(), torch.atleast_2d(min_bounds), torch.atleast_2d(max_bounds)], -1
    )
    return UniformSubPriorParams.from_tensor(stacked)
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor
|
|
3
|
+
|
|
4
|
+
from reflectorch.data_generation.priors.utils import uniform_sampler
|
|
5
|
+
from reflectorch.data_generation.priors.subprior_sampler import UniformSubPriorParams
|
|
6
|
+
from reflectorch.data_generation.priors.params import Params
|
|
7
|
+
from reflectorch.data_generation.likelihoods import LogLikelihood
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def simple_sampler_solution(
        likelihood: LogLikelihood,
        predicted_params: UniformSubPriorParams,
        total_min_bounds: Tensor,
        total_max_bounds: Tensor,
        num: int = 2 ** 15,
        coef: float = 0.1,
) -> UniformSubPriorParams:
    """Refine a predicted parameter set by random search around it.

    Draws ``num`` candidates near ``predicted_params`` with ``sample_around_params``,
    scores them with ``likelihood`` via ``get_best_mse_param`` and returns the best
    candidate together with the prediction's ``min_bounds`` / ``max_bounds``.
    """
    candidates_t = sample_around_params(
        predicted_params, total_min_bounds, total_max_bounds, num=num, coef=coef,
    )
    candidates = Params.from_tensor(candidates_t)
    lower, upper = predicted_params.min_bounds, predicted_params.max_bounds
    return get_best_mse_param(candidates, likelihood, lower, upper)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def sample_around_params(predicted_params: UniformSubPriorParams,
                         total_min_bounds: Tensor,
                         total_max_bounds: Tensor,
                         num: int = 2 ** 15,
                         coef: float = 0.1,
                         ) -> Tensor:
    """Uniformly sample ``num`` parameter vectors in a box around the prediction.

    The box half-width is ``coef`` times the width of the prediction's own
    ``[min_bounds, max_bounds]`` interval, and both box edges are clipped to
    ``[total_min_bounds, total_max_bounds]``. Entry 0 of the sampled tensor is then
    overwritten with entry 0 of the prediction, so the unperturbed prediction is
    always among the candidates (assumes ``uniform_sampler`` returns the samples
    along the first dimension; confirm against its definition).
    """
    center = predicted_params.as_tensor(add_bounds=False)
    half_width = (predicted_params.max_bounds - predicted_params.min_bounds) * coef

    lower = torch.clamp(center - half_width, total_min_bounds, total_max_bounds)
    upper = torch.clamp(center + half_width, total_min_bounds, total_max_bounds)

    samples = uniform_sampler(lower, upper, num, center.shape[-1])
    samples[0] = center[0]
    return samples
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def get_best_mse_param(
        params: Params,
        likelihood: LogLikelihood,
        min_bounds: Tensor = None,
        max_bounds: Tensor = None,
):
    """Select the candidate in ``params`` with the highest log-likelihood.

    Despite the name, candidates are ranked by ``likelihood.calc_log_likelihood``
    (argmax) of their reflectivity curves at ``likelihood.q``. The winner is returned
    as a one-element slice of ``params``. If ``min_bounds`` is given, it is instead
    rebuilt as ``UniformSubPriorParams`` with ``min_bounds`` and ``max_bounds``
    concatenated after the parameter values.
    """
    curves = params.reflectivity(likelihood.q)
    scores = likelihood.calc_log_likelihood(curves)
    top = torch.argmax(scores)
    winner = params[top:top + 1]

    if min_bounds is None:
        return winner

    stacked = torch.cat(
        [winner.as_tensor(), torch.atleast_2d(min_bounds), torch.atleast_2d(max_bounds)], -1
    )
    return UniformSubPriorParams.from_tensor(stacked)
|