reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,112 +1,112 @@
|
|
|
1
|
-
from pathlib import Path
|
|
2
|
-
|
|
3
|
-
import torch
|
|
4
|
-
from torch import Tensor
|
|
5
|
-
|
|
6
|
-
from reflectorch.data_generation.priors import PriorSampler
|
|
7
|
-
from reflectorch.paths import SAVED_MODELS_DIR
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
class CurvesScaler(object):
|
|
11
|
-
"""Base class for curve scalers"""
|
|
12
|
-
def scale(self, curves: Tensor):
|
|
13
|
-
raise NotImplementedError
|
|
14
|
-
|
|
15
|
-
def restore(self, curves: Tensor):
|
|
16
|
-
raise NotImplementedError
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
class LogAffineCurvesScaler(CurvesScaler):
|
|
20
|
-
""" Curve scaler which scales the reflectivity curves according to the logarithmic affine transformation:
|
|
21
|
-
:math:`\log_{10}(R + eps) \cdot weight + bias`.
|
|
22
|
-
|
|
23
|
-
Args:
|
|
24
|
-
weight (float): multiplication factor in the transformation
|
|
25
|
-
bias (float): addition term in the transformation
|
|
26
|
-
eps (float): sets the minimum intensity value of the reflectivity curves which is considered
|
|
27
|
-
"""
|
|
28
|
-
def __init__(self, weight: float = 0.1, bias: float = 0.5, eps: float = 1e-10):
|
|
29
|
-
self.weight = weight
|
|
30
|
-
self.bias = bias
|
|
31
|
-
self.eps = eps
|
|
32
|
-
|
|
33
|
-
def scale(self, curves: Tensor):
|
|
34
|
-
"""scales the reflectivity curves to a ML-friendly range
|
|
35
|
-
|
|
36
|
-
Args:
|
|
37
|
-
curves (Tensor): original reflectivity curves
|
|
38
|
-
|
|
39
|
-
Returns:
|
|
40
|
-
Tensor: reflectivity curves scaled to a ML-friendly range
|
|
41
|
-
"""
|
|
42
|
-
return torch.log10(curves + self.eps) * self.weight + self.bias
|
|
43
|
-
|
|
44
|
-
def restore(self, curves: Tensor):
|
|
45
|
-
"""restores the physical reflectivity curves
|
|
46
|
-
|
|
47
|
-
Args:
|
|
48
|
-
curves (Tensor): scaled reflectivity curves
|
|
49
|
-
|
|
50
|
-
Returns:
|
|
51
|
-
Tensor: reflectivity curves restored to the physical range
|
|
52
|
-
"""
|
|
53
|
-
return 10 ** ((curves - self.bias) / self.weight) - self.eps
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
class MeanNormalizationCurvesScaler(CurvesScaler):
|
|
57
|
-
"""Curve scaler which scales the reflectivity curves by the precomputed mean of a batch of curves
|
|
58
|
-
|
|
59
|
-
Args:
|
|
60
|
-
path (str, optional): path to the precomputed mean of the curves, only used if ``curves_mean`` is None. Defaults to None.
|
|
61
|
-
curves_mean (Tensor, optional): the precomputed mean of the curves. Defaults to None.
|
|
62
|
-
device (torch.device, optional): the Pytorch device. Defaults to 'cuda'.
|
|
63
|
-
"""
|
|
64
|
-
|
|
65
|
-
def __init__(self, path: str = None, curves_mean: Tensor = None, device: torch.device = 'cuda'):
|
|
66
|
-
if curves_mean is None:
|
|
67
|
-
curves_mean = torch.load(self.get_path(path))
|
|
68
|
-
self.curves_mean = curves_mean.to(device)
|
|
69
|
-
|
|
70
|
-
def scale(self, curves: Tensor):
|
|
71
|
-
"""scales the reflectivity curves to a ML-friendly range
|
|
72
|
-
|
|
73
|
-
Args:
|
|
74
|
-
curves (Tensor): original reflectivity curves
|
|
75
|
-
|
|
76
|
-
Returns:
|
|
77
|
-
Tensor: reflectivity curves scaled to a ML-friendly range
|
|
78
|
-
"""
|
|
79
|
-
self.curves_mean = self.curves_mean.to(curves)
|
|
80
|
-
return curves / self.curves_mean - 1
|
|
81
|
-
|
|
82
|
-
def restore(self, curves: Tensor):
|
|
83
|
-
"""restores the physical reflectivity curves
|
|
84
|
-
|
|
85
|
-
Args:
|
|
86
|
-
curves (Tensor): scaled reflectivity curves
|
|
87
|
-
|
|
88
|
-
Returns:
|
|
89
|
-
Tensor: reflectivity curves restored to the physical range
|
|
90
|
-
"""
|
|
91
|
-
self.curves_mean = self.curves_mean.to(curves)
|
|
92
|
-
return (curves + 1) * self.curves_mean
|
|
93
|
-
|
|
94
|
-
@staticmethod
|
|
95
|
-
def save(prior_sampler: PriorSampler, q: Tensor, path: str, num: int = 16384):
|
|
96
|
-
"""computes the mean of a batch of reflectivity curves and saves it
|
|
97
|
-
|
|
98
|
-
Args:
|
|
99
|
-
prior_sampler (PriorSampler): the prior sampler
|
|
100
|
-
q (Tensor): the q values
|
|
101
|
-
path (str): the path for saving the mean of the curves
|
|
102
|
-
num (int, optional): the number of curves used to compute the mean. Defaults to 16384.
|
|
103
|
-
"""
|
|
104
|
-
params = prior_sampler.sample(num)
|
|
105
|
-
curves_mean = params.reflectivity(q, log=False).mean(0).cpu()
|
|
106
|
-
torch.save(curves_mean, MeanNormalizationCurvesScaler.get_path(path))
|
|
107
|
-
|
|
108
|
-
@staticmethod
|
|
109
|
-
def get_path(path: str) -> Path:
|
|
110
|
-
if not path.endswith('.pt'):
|
|
111
|
-
path = path + '.pt'
|
|
112
|
-
return SAVED_MODELS_DIR / path
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from reflectorch.data_generation.priors import PriorSampler
|
|
7
|
+
from reflectorch.paths import SAVED_MODELS_DIR
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CurvesScaler(object):
|
|
11
|
+
"""Base class for curve scalers"""
|
|
12
|
+
def scale(self, curves: Tensor):
|
|
13
|
+
raise NotImplementedError
|
|
14
|
+
|
|
15
|
+
def restore(self, curves: Tensor):
|
|
16
|
+
raise NotImplementedError
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LogAffineCurvesScaler(CurvesScaler):
|
|
20
|
+
""" Curve scaler which scales the reflectivity curves according to the logarithmic affine transformation:
|
|
21
|
+
:math:`\log_{10}(R + eps) \cdot weight + bias`.
|
|
22
|
+
|
|
23
|
+
Args:
|
|
24
|
+
weight (float): multiplication factor in the transformation
|
|
25
|
+
bias (float): addition term in the transformation
|
|
26
|
+
eps (float): sets the minimum intensity value of the reflectivity curves which is considered
|
|
27
|
+
"""
|
|
28
|
+
def __init__(self, weight: float = 0.1, bias: float = 0.5, eps: float = 1e-10):
|
|
29
|
+
self.weight = weight
|
|
30
|
+
self.bias = bias
|
|
31
|
+
self.eps = eps
|
|
32
|
+
|
|
33
|
+
def scale(self, curves: Tensor):
|
|
34
|
+
"""scales the reflectivity curves to a ML-friendly range
|
|
35
|
+
|
|
36
|
+
Args:
|
|
37
|
+
curves (Tensor): original reflectivity curves
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
Tensor: reflectivity curves scaled to a ML-friendly range
|
|
41
|
+
"""
|
|
42
|
+
return torch.log10(curves + self.eps) * self.weight + self.bias
|
|
43
|
+
|
|
44
|
+
def restore(self, curves: Tensor):
|
|
45
|
+
"""restores the physical reflectivity curves
|
|
46
|
+
|
|
47
|
+
Args:
|
|
48
|
+
curves (Tensor): scaled reflectivity curves
|
|
49
|
+
|
|
50
|
+
Returns:
|
|
51
|
+
Tensor: reflectivity curves restored to the physical range
|
|
52
|
+
"""
|
|
53
|
+
return 10 ** ((curves - self.bias) / self.weight) - self.eps
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class MeanNormalizationCurvesScaler(CurvesScaler):
|
|
57
|
+
"""Curve scaler which scales the reflectivity curves by the precomputed mean of a batch of curves
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
path (str, optional): path to the precomputed mean of the curves, only used if ``curves_mean`` is None. Defaults to None.
|
|
61
|
+
curves_mean (Tensor, optional): the precomputed mean of the curves. Defaults to None.
|
|
62
|
+
device (torch.device, optional): the Pytorch device. Defaults to 'cuda'.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
def __init__(self, path: str = None, curves_mean: Tensor = None, device: torch.device = 'cuda'):
|
|
66
|
+
if curves_mean is None:
|
|
67
|
+
curves_mean = torch.load(self.get_path(path))
|
|
68
|
+
self.curves_mean = curves_mean.to(device)
|
|
69
|
+
|
|
70
|
+
def scale(self, curves: Tensor):
|
|
71
|
+
"""scales the reflectivity curves to a ML-friendly range
|
|
72
|
+
|
|
73
|
+
Args:
|
|
74
|
+
curves (Tensor): original reflectivity curves
|
|
75
|
+
|
|
76
|
+
Returns:
|
|
77
|
+
Tensor: reflectivity curves scaled to a ML-friendly range
|
|
78
|
+
"""
|
|
79
|
+
self.curves_mean = self.curves_mean.to(curves)
|
|
80
|
+
return curves / self.curves_mean - 1
|
|
81
|
+
|
|
82
|
+
def restore(self, curves: Tensor):
|
|
83
|
+
"""restores the physical reflectivity curves
|
|
84
|
+
|
|
85
|
+
Args:
|
|
86
|
+
curves (Tensor): scaled reflectivity curves
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
Tensor: reflectivity curves restored to the physical range
|
|
90
|
+
"""
|
|
91
|
+
self.curves_mean = self.curves_mean.to(curves)
|
|
92
|
+
return (curves + 1) * self.curves_mean
|
|
93
|
+
|
|
94
|
+
@staticmethod
|
|
95
|
+
def save(prior_sampler: PriorSampler, q: Tensor, path: str, num: int = 16384):
|
|
96
|
+
"""computes the mean of a batch of reflectivity curves and saves it
|
|
97
|
+
|
|
98
|
+
Args:
|
|
99
|
+
prior_sampler (PriorSampler): the prior sampler
|
|
100
|
+
q (Tensor): the q values
|
|
101
|
+
path (str): the path for saving the mean of the curves
|
|
102
|
+
num (int, optional): the number of curves used to compute the mean. Defaults to 16384.
|
|
103
|
+
"""
|
|
104
|
+
params = prior_sampler.sample(num)
|
|
105
|
+
curves_mean = params.reflectivity(q, log=False).mean(0).cpu()
|
|
106
|
+
torch.save(curves_mean, MeanNormalizationCurvesScaler.get_path(path))
|
|
107
|
+
|
|
108
|
+
@staticmethod
|
|
109
|
+
def get_path(path: str) -> Path:
|
|
110
|
+
if not path.endswith('.pt'):
|
|
111
|
+
path = path + '.pt'
|
|
112
|
+
return SAVED_MODELS_DIR / path
|
|
@@ -1,99 +1,99 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from torch import Tensor
|
|
3
|
-
|
|
4
|
-
from reflectorch.data_generation.priors.parametric_subpriors import BasicParams
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
class Smearing(object):
|
|
8
|
-
"""Class which applies resolution smearing to the reflectivity curves.
|
|
9
|
-
The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.
|
|
10
|
-
|
|
11
|
-
Args:
|
|
12
|
-
sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (0.01, 0.1).
|
|
13
|
-
constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
|
|
14
|
-
otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to True.
|
|
15
|
-
gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
|
|
16
|
-
share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
|
|
17
|
-
"""
|
|
18
|
-
def __init__(self,
|
|
19
|
-
sigma_range: tuple = (0.01, 0.1),
|
|
20
|
-
constant_dq: bool = False,
|
|
21
|
-
gauss_num: int = 31,
|
|
22
|
-
share_smeared: float = 0.2,
|
|
23
|
-
):
|
|
24
|
-
self.sigma_min, self.sigma_max = sigma_range
|
|
25
|
-
self.sigma_delta = self.sigma_max - self.sigma_min
|
|
26
|
-
self.constant_dq = constant_dq
|
|
27
|
-
self.gauss_num = gauss_num
|
|
28
|
-
self.share_smeared = share_smeared
|
|
29
|
-
|
|
30
|
-
def __repr__(self):
|
|
31
|
-
return f'Smearing(({self.sigma_min}, {self.sigma_max})'
|
|
32
|
-
|
|
33
|
-
def generate_resolutions(self, batch_size: int, device=None, dtype=None):
|
|
34
|
-
num_smeared = int(batch_size * self.share_smeared)
|
|
35
|
-
if not num_smeared:
|
|
36
|
-
return None, None
|
|
37
|
-
dq = torch.rand(num_smeared, 1, device=device, dtype=dtype) * self.sigma_delta + self.sigma_min
|
|
38
|
-
indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
|
|
39
|
-
indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
|
|
40
|
-
return dq, indices
|
|
41
|
-
|
|
42
|
-
def scale_resolutions(self, resolutions: Tensor) -> Tensor:
|
|
43
|
-
"""Scales the q-resolution values to [-1,1] range using the internal sigma range"""
|
|
44
|
-
sigma_min = 0.0 if self.share_smeared != 1.0 else self.sigma_min
|
|
45
|
-
return 2 * (resolutions - sigma_min) / (self.sigma_max - sigma_min) - 1
|
|
46
|
-
|
|
47
|
-
def get_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs:dict = None):
|
|
48
|
-
refl_kwargs = refl_kwargs or {}
|
|
49
|
-
|
|
50
|
-
dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)
|
|
51
|
-
q_resolutions = torch.zeros(q_values.shape[0], 1, dtype=q_values.dtype, device=q_values.device)
|
|
52
|
-
|
|
53
|
-
if dq is None:
|
|
54
|
-
return params.reflectivity(q_values, **refl_kwargs), q_resolutions
|
|
55
|
-
|
|
56
|
-
refl_kwargs_not_smeared = {}
|
|
57
|
-
refl_kwargs_smeared = {}
|
|
58
|
-
for key, value in refl_kwargs.items():
|
|
59
|
-
if isinstance(value, torch.Tensor) and value.shape[0] == params.batch_size:
|
|
60
|
-
refl_kwargs_not_smeared[key] = value[~indices]
|
|
61
|
-
refl_kwargs_smeared[key] = value[indices]
|
|
62
|
-
else:
|
|
63
|
-
refl_kwargs_not_smeared[key] = value
|
|
64
|
-
refl_kwargs_smeared[key] = value
|
|
65
|
-
|
|
66
|
-
# Compute unsmeared reflectivity
|
|
67
|
-
if (~indices).sum().item():
|
|
68
|
-
if q_values.dim() == 2 and q_values.shape[0] > 1:
|
|
69
|
-
q = q_values[~indices]
|
|
70
|
-
else:
|
|
71
|
-
q = q_values
|
|
72
|
-
|
|
73
|
-
reflectivity_not_smeared = params[~indices].reflectivity(q, **refl_kwargs_not_smeared)
|
|
74
|
-
else:
|
|
75
|
-
reflectivity_not_smeared = None
|
|
76
|
-
|
|
77
|
-
# Compute smeared reflectivity
|
|
78
|
-
if indices.sum().item():
|
|
79
|
-
if q_values.dim() == 2 and q_values.shape[0] > 1:
|
|
80
|
-
q = q_values[indices]
|
|
81
|
-
else:
|
|
82
|
-
q = q_values
|
|
83
|
-
|
|
84
|
-
reflectivity_smeared = params[indices].reflectivity(
|
|
85
|
-
q, dq=dq, constant_dq=self.constant_dq, gauss_num=self.gauss_num, **refl_kwargs_smeared
|
|
86
|
-
)
|
|
87
|
-
else:
|
|
88
|
-
reflectivity_smeared = None
|
|
89
|
-
|
|
90
|
-
curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)
|
|
91
|
-
|
|
92
|
-
if (~indices).sum().item():
|
|
93
|
-
curves[~indices] = reflectivity_not_smeared
|
|
94
|
-
|
|
95
|
-
curves[indices] = reflectivity_smeared
|
|
96
|
-
|
|
97
|
-
q_resolutions[indices] = dq
|
|
98
|
-
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor
|
|
3
|
+
|
|
4
|
+
from reflectorch.data_generation.priors.parametric_subpriors import BasicParams
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Smearing(object):
|
|
8
|
+
"""Class which applies resolution smearing to the reflectivity curves.
|
|
9
|
+
The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.
|
|
10
|
+
|
|
11
|
+
Args:
|
|
12
|
+
sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (0.01, 0.1).
|
|
13
|
+
constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
|
|
14
|
+
otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to True.
|
|
15
|
+
gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
|
|
16
|
+
share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
|
|
17
|
+
"""
|
|
18
|
+
def __init__(self,
|
|
19
|
+
sigma_range: tuple = (0.01, 0.1),
|
|
20
|
+
constant_dq: bool = False,
|
|
21
|
+
gauss_num: int = 31,
|
|
22
|
+
share_smeared: float = 0.2,
|
|
23
|
+
):
|
|
24
|
+
self.sigma_min, self.sigma_max = sigma_range
|
|
25
|
+
self.sigma_delta = self.sigma_max - self.sigma_min
|
|
26
|
+
self.constant_dq = constant_dq
|
|
27
|
+
self.gauss_num = gauss_num
|
|
28
|
+
self.share_smeared = share_smeared
|
|
29
|
+
|
|
30
|
+
def __repr__(self):
|
|
31
|
+
return f'Smearing(({self.sigma_min}, {self.sigma_max})'
|
|
32
|
+
|
|
33
|
+
def generate_resolutions(self, batch_size: int, device=None, dtype=None):
|
|
34
|
+
num_smeared = int(batch_size * self.share_smeared)
|
|
35
|
+
if not num_smeared:
|
|
36
|
+
return None, None
|
|
37
|
+
dq = torch.rand(num_smeared, 1, device=device, dtype=dtype) * self.sigma_delta + self.sigma_min
|
|
38
|
+
indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
|
|
39
|
+
indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
|
|
40
|
+
return dq, indices
|
|
41
|
+
|
|
42
|
+
def scale_resolutions(self, resolutions: Tensor) -> Tensor:
|
|
43
|
+
"""Scales the q-resolution values to [-1,1] range using the internal sigma range"""
|
|
44
|
+
sigma_min = 0.0 if self.share_smeared != 1.0 else self.sigma_min
|
|
45
|
+
return 2 * (resolutions - sigma_min) / (self.sigma_max - sigma_min) - 1
|
|
46
|
+
|
|
47
|
+
def get_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs:dict = None):
|
|
48
|
+
refl_kwargs = refl_kwargs or {}
|
|
49
|
+
|
|
50
|
+
dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)
|
|
51
|
+
q_resolutions = torch.zeros(q_values.shape[0], 1, dtype=q_values.dtype, device=q_values.device)
|
|
52
|
+
|
|
53
|
+
if dq is None:
|
|
54
|
+
return params.reflectivity(q_values, **refl_kwargs), q_resolutions
|
|
55
|
+
|
|
56
|
+
refl_kwargs_not_smeared = {}
|
|
57
|
+
refl_kwargs_smeared = {}
|
|
58
|
+
for key, value in refl_kwargs.items():
|
|
59
|
+
if isinstance(value, torch.Tensor) and value.shape[0] == params.batch_size:
|
|
60
|
+
refl_kwargs_not_smeared[key] = value[~indices]
|
|
61
|
+
refl_kwargs_smeared[key] = value[indices]
|
|
62
|
+
else:
|
|
63
|
+
refl_kwargs_not_smeared[key] = value
|
|
64
|
+
refl_kwargs_smeared[key] = value
|
|
65
|
+
|
|
66
|
+
# Compute unsmeared reflectivity
|
|
67
|
+
if (~indices).sum().item():
|
|
68
|
+
if q_values.dim() == 2 and q_values.shape[0] > 1:
|
|
69
|
+
q = q_values[~indices]
|
|
70
|
+
else:
|
|
71
|
+
q = q_values
|
|
72
|
+
|
|
73
|
+
reflectivity_not_smeared = params[~indices].reflectivity(q, **refl_kwargs_not_smeared)
|
|
74
|
+
else:
|
|
75
|
+
reflectivity_not_smeared = None
|
|
76
|
+
|
|
77
|
+
# Compute smeared reflectivity
|
|
78
|
+
if indices.sum().item():
|
|
79
|
+
if q_values.dim() == 2 and q_values.shape[0] > 1:
|
|
80
|
+
q = q_values[indices]
|
|
81
|
+
else:
|
|
82
|
+
q = q_values
|
|
83
|
+
|
|
84
|
+
reflectivity_smeared = params[indices].reflectivity(
|
|
85
|
+
q, dq=dq, constant_dq=self.constant_dq, gauss_num=self.gauss_num, **refl_kwargs_smeared
|
|
86
|
+
)
|
|
87
|
+
else:
|
|
88
|
+
reflectivity_smeared = None
|
|
89
|
+
|
|
90
|
+
curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)
|
|
91
|
+
|
|
92
|
+
if (~indices).sum().item():
|
|
93
|
+
curves[~indices] = reflectivity_not_smeared
|
|
94
|
+
|
|
95
|
+
curves[indices] = reflectivity_smeared
|
|
96
|
+
|
|
97
|
+
q_resolutions[indices] = dq
|
|
98
|
+
|
|
99
99
|
return curves, q_resolutions
|