reflectorch 1.2.1__py3-none-any.whl → 1.3.0__py3-none-any.whl
This diff shows the content changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- reflectorch/data_generation/__init__.py +2 -0
- reflectorch/data_generation/dataset.py +27 -7
- reflectorch/data_generation/noise.py +115 -9
- reflectorch/data_generation/priors/parametric_models.py +90 -15
- reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
- reflectorch/data_generation/priors/sampler_strategies.py +67 -3
- reflectorch/data_generation/q_generator.py +31 -11
- reflectorch/data_generation/reflectivity/__init__.py +53 -11
- reflectorch/data_generation/reflectivity/kinematical.py +4 -5
- reflectorch/data_generation/reflectivity/smearing.py +25 -10
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/smearing.py +42 -11
- reflectorch/data_generation/utils.py +92 -18
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/inference_model.py +216 -103
- reflectorch/inference/plotting.py +98 -0
- reflectorch/inference/scipy_fitter.py +84 -7
- reflectorch/ml/__init__.py +2 -0
- reflectorch/ml/basic_trainer.py +18 -6
- reflectorch/ml/callbacks.py +5 -4
- reflectorch/ml/loggers.py +25 -0
- reflectorch/ml/schedulers.py +116 -0
- reflectorch/ml/trainers.py +122 -23
- reflectorch/models/__init__.py +1 -1
- reflectorch/models/encoders/__init__.py +0 -2
- reflectorch/models/encoders/conv_encoder.py +54 -40
- reflectorch/models/encoders/fno.py +23 -16
- reflectorch/models/networks/__init__.py +2 -0
- reflectorch/models/networks/mlp_networks.py +324 -152
- reflectorch/models/networks/residual_net.py +31 -5
- reflectorch/runs/train.py +0 -1
- reflectorch/runs/utils.py +43 -9
- {reflectorch-1.2.1.dist-info → reflectorch-1.3.0.dist-info}/METADATA +19 -17
- {reflectorch-1.2.1.dist-info → reflectorch-1.3.0.dist-info}/RECORD +37 -34
- {reflectorch-1.2.1.dist-info → reflectorch-1.3.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.2.1.dist-info → reflectorch-1.3.0.dist-info/licenses}/LICENSE.txt +0 -0
- {reflectorch-1.2.1.dist-info → reflectorch-1.3.0.dist-info}/top_level.txt +0 -0
--- a/reflectorch/data_generation/q_generator.py
+++ b/reflectorch/data_generation/q_generator.py
@@ -50,6 +50,8 @@ class ConstantQ(QGenerator):
             q = q[1:]
         else:
             q = q[:-1]
+        self.q_min = q.min().item()
+        self.q_max = q.max().item()
         self.q = q

     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
@@ -63,14 +65,26 @@ class ConstantQ(QGenerator):
         """
         return self.q.clone()[None].expand(batch_size, self.q.shape[0])

+    def scale_q(self, q):
+        """Scales the q values to the range [-1, 1].
+
+        Args:
+            q (Tensor): unscaled q values
+
+        Returns:
+            Tensor: scaled q values
+        """
+        scaled_q_01 = (q - self.q_min) / (self.q_max - self.q_min)
+        return 2.0 * (scaled_q_01 - 0.5)
+

 class VariableQ(QGenerator):
     """Q generator for reflectivity curves with variable discretization

     Args:
-        q_min_range (list, optional): the range for sampling the minimum q value of the curves,
-        q_max_range (list, optional): the range for sampling the maximum q value of the curves,
-        n_q_range (list, optional): the range for the number of points in the curves (equidistantly sampled between
+        q_min_range (list, optional): the range for sampling the minimum q value of the curves, q_min. Defaults to [0.01, 0.03].
+        q_max_range (list, optional): the range for sampling the maximum q value of the curves, q_max. Defaults to [0.1, 0.5].
+        n_q_range (list, optional): the range for the number of points in the curves (equidistantly sampled between q_min and q_max,
            the number of points varies between batches but is constant within a batch). Defaults to [64, 256].
         device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
         dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
@@ -80,12 +94,14 @@ class VariableQ(QGenerator):
         q_min_range: Tuple[float, float] = (0.01, 0.03),
         q_max_range: Tuple[float, float] = (0.1, 0.5),
         n_q_range: Tuple[int, int] = (64, 256),
+        mode: str = 'equidistant',
         device=DEFAULT_DEVICE,
         dtype=DEFAULT_DTYPE,
     ):
         self.q_min_range = q_min_range
         self.q_max_range = q_max_range
         self.n_q_range = n_q_range
+        self.mode = mode
         self.device = device
         self.dtype = dtype

@@ -98,14 +114,18 @@ class VariableQ(QGenerator):
         Returns:
             Tensor: generated batch of q values
         """
-
-
-
-
-
-
-
-
+
+        q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
+        q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
+
+        n_q = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (1,), device=self.device).item()
+
+        if self.mode == 'equidistant':
+            q = torch.linspace(0, 1, n_q, device=self.device, dtype=self.dtype)
+        elif self.mode == 'random':
+            q = torch.rand(n_q, device=self.device, dtype=self.dtype).sort().values
+
+        q = q_min[:, None] + q * (q_max - q_min)[:, None]

         return q

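The hunks above add q_min/q_max bookkeeping and a scale_q helper to ConstantQ, and a mode switch ('equidistant' or 'random') to VariableQ. A minimal sketch of the new VariableQ behaviour, using only arguments visible in this diff (the ConstantQ constructor is not shown here, so its use is only indicated in a comment; values are illustrative):

from reflectorch.data_generation.q_generator import VariableQ

# 'random' mode draws n_q points uniformly in [q_min, q_max] and sorts them,
# instead of the equidistant grid produced by the default 'equidistant' mode
q_gen = VariableQ(q_min_range=[0.01, 0.03], q_max_range=[0.1, 0.5],
                  n_q_range=[64, 256], mode='random', device='cpu')
q_batch = q_gen.get_batch(8)   # shape (8, n_q); n_q is resampled for each batch

# ConstantQ.scale_q maps q from [q_min, q_max] to [-1, 1], e.g.:
# scaled = constant_q_generator.scale_q(q_from_constant_q_generator)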
--- a/reflectorch/data_generation/reflectivity/__init__.py
+++ b/reflectorch/data_generation/reflectivity/__init__.py
@@ -10,6 +10,7 @@ from reflectorch.data_generation.reflectivity.numpy_implementations import (
     abeles_np,
 )
 from reflectorch.data_generation.reflectivity.smearing import abeles_constant_smearing
+from reflectorch.data_generation.reflectivity.smearing_pointwise import abeles_pointwise_smearing
 from reflectorch.data_generation.reflectivity.kinematical import kinematical_approximation


@@ -20,9 +21,15 @@ def reflectivity(
     sld: Tensor,
     dq: Tensor = None,
     gauss_num: int = 51,
-    constant_dq: bool =
+    constant_dq: bool = False,
     log: bool = False,
-
+    q_shift: Tensor = 0.0,
+    r_scale: Tensor = 1.0,
+    background: Tensor = 0.0,
+    solvent_vf = None,
+    solvent_mode = 'fronting',
+    abeles_func = None,
+    **abeles_kwargs
 ):
     """Function which computes the reflectivity curves from thin film parameters.
     By default it uses the fast implementation of the Abeles matrix formalism.
@@ -37,24 +44,59 @@ def reflectivity(
            Either dq if ``constant_dq`` is ``True`` or dq/q if ``constant_dq`` is ``False``. Defaults to None.
         gauss_num (int, optional): the number of gaussians for curve smearing. Defaults to 51.
         constant_dq (bool, optional): if ``True`` the smearing is constant (constant dq at each point in the curve)
-            otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to
+            otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to False.
         log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to False.
+        q_shift (float or Tensor, optional): misalignment in q.
+        r_scale (float or Tensor, optional): normalization factor (scales reflectivity).
+        background (float or Tensor, optional): background intensity.
         abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default Abeles matrix implementation ('abeles'). Defaults to None.
-
+        abeles_kwargs: Additional arguments specific to the chosen `abeles_func`.
     Returns:
-        Tensor:
+        Tensor: the computed reflectivity curves
     """
     abeles_func = abeles_func or abeles
-    q = torch.atleast_2d(q)
+    q = torch.atleast_2d(q) + q_shift
+    q = torch.clamp(q, min=0.0)
+
+    if solvent_vf is not None:
+        num_layers = thickness.shape[-1]
+        if solvent_mode == 'fronting':
+            assert sld.shape[-1] == num_layers + 2
+            assert solvent_vf.shape[-1] == num_layers
+            solvent_sld = sld[..., [0]]
+            idx = slice(1, num_layers)
+            sld[..., idx] = solvent_vf * solvent_sld + (1.0 - solvent_vf) * sld[..., idx]
+        elif solvent_mode == 'backing':
+            solvent_sld = sld[..., [-1]]
+            idx = slice(1, num_layers) if sld.shape[-1] == num_layers + 2 else slice(0, num_layers)
+            sld[..., idx] = solvent_vf * solvent_sld + (1.0 - solvent_vf) * sld[..., idx]
+        else:
+            raise NotImplementedError

     if dq is None:
-        reflectivity_curves = abeles_func(q, thickness, roughness, sld)
+        reflectivity_curves = abeles_func(q, thickness, roughness, sld, **abeles_kwargs)
     else:
-
-
-
-
+        if dq.shape[-1] > 1:
+            reflectivity_curves = abeles_pointwise_smearing(
+                q=q, dq=dq, thickness=thickness, roughness=roughness, sld=sld,
+                abeles_func=abeles_func, gauss_num=gauss_num,
+                **abeles_kwargs,
+            )
+        else:
+            reflectivity_curves = abeles_constant_smearing(
+                q, thickness, roughness, sld,
+                dq=dq, gauss_num=gauss_num, constant_dq=constant_dq, abeles_func=abeles_func,
+                **abeles_kwargs,
+            )
+
+    if isinstance(r_scale, Tensor):
+        r_scale = r_scale.view(-1, *[1] * (reflectivity_curves.dim() - 1))
+    if isinstance(background, Tensor):
+        background = background.view(-1, *[1] * (reflectivity_curves.dim() - 1))
+
+    reflectivity_curves = reflectivity_curves * r_scale + background

     if log:
         reflectivity_curves = torch.log10(reflectivity_curves)
+
     return reflectivity_curves
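For orientation, a minimal sketch of calling the extended reflectivity function with the new nuisance parameters. The argument order for q/thickness/roughness/sld is inferred from the abeles call shown in the hunk, and the numerical values are arbitrary, so treat this as an illustration rather than reference usage:

import torch
from reflectorch.data_generation.reflectivity import reflectivity

batch_size, n_layers, n_q = 4, 2, 128
q = torch.linspace(0.02, 0.3, n_q).expand(batch_size, n_q)
thickness = torch.full((batch_size, n_layers), 100.0)          # layer thicknesses
roughness = torch.full((batch_size, n_layers + 1), 3.0)        # interlayer + substrate roughnesses
sld = torch.tensor([[3.5, 2.0, 20.1]]).expand(batch_size, -1)  # layer + substrate SLDs

# new in 1.3.0: q misalignment, intensity scale and flat background,
# applied as R -> R * r_scale + background before the optional log10
curves = reflectivity(
    q, thickness, roughness, sld,
    q_shift=torch.full((batch_size, 1), 1e-3),
    r_scale=torch.full((batch_size,), 0.95),
    background=torch.full((batch_size,), 1e-7),
    log=True,
)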
--- a/reflectorch/data_generation/reflectivity/kinematical.py
+++ b/reflectorch/data_generation/reflectivity/kinematical.py
@@ -51,7 +51,7 @@ def kinematical_approximation(

     substrate_sld = sld[:, -1:]

-    rf =
+    rf = _get_fresnel_reflectivity(q, substrate_sld[:, None])

     r = torch.clamp_max_(r * rf / substrate_sld.real ** 2, 1.)

@@ -61,12 +61,11 @@ def kinematical_approximation(
     return r


-def
-    _RE_CONST = 0.28174103675406496
+def _get_fresnel_reflectivity(q, substrate_slds):
+    _RE_CONST = 0.28174103675406496 # 2/sqrt(16*pi)

     q_c = torch.sqrt(substrate_slds + 0j) / _RE_CONST * 2
     q_prime = torch.sqrt(q ** 2 - q_c ** 2 + 0j)
     r_f = ((q - q_prime) / (q + q_prime)).abs().float() ** 2

-    return r_f.squeeze(-1)
-
+    return r_f.squeeze(-1)
--- a/reflectorch/data_generation/reflectivity/smearing.py
+++ b/reflectorch/data_generation/reflectivity/smearing.py
@@ -14,26 +14,41 @@ def abeles_constant_smearing(
     roughness: Tensor,
     sld: Tensor,
     dq: Tensor = None,
-    gauss_num: int =
-    constant_dq: bool =
+    gauss_num: int = 31,
+    constant_dq: bool = False,
     abeles_func=None,
+    **abeles_kwargs
 ):
     abeles_func = abeles_func or abeles
+
+    if dq.dtype != thickness.dtype:
+        q = q.to(thickness)
+
+    if dq.dtype != thickness.dtype:
+        dq = dq.to(thickness)
+
+    if q.shape[0] == 1:
+        q = q.repeat(thickness.shape[0], 1)
+
     q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
     kernels = _get_t_gauss_kernels(dq, gauss_num)
-
-    curves = abeles_func(q_lin, thickness, roughness, sld)
+
+    curves = abeles_func(q_lin, thickness, roughness, sld, **abeles_kwargs)

     padding = (kernels.shape[-1] - 1) // 2
+    padded_curves = pad(curves, (padding, padding), 'reflect')
+
     smeared_curves = conv1d(
-
-    )
+        padded_curves, kernels[:, None], groups=kernels.shape[0],
+    )

     if q.shape[0] != smeared_curves.shape[0]:
-
-
+        repeat_factor = smeared_curves.shape[0] // q.shape[0]
+        q = q.repeat(repeat_factor, 1)
+        q_lin = q_lin.repeat(repeat_factor, 1)
+
     smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)
-
+
     return smeared_curves


@@ -55,7 +70,7 @@ def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
     return gauss_y


-def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool =
+def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = False):
     if constant_dq:
         return _get_q_axes_for_constant_dq(q, resolutions, gaussnum)
     else:
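As the reflectivity hunk earlier in this diff shows, the shape of dq now selects the smearing backend: one value per curve (shape (batch, 1)) goes through abeles_constant_smearing above, while a per-point resolution array (same shape as q) is routed to the new abeles_pointwise_smearing defined in the next file. A small self-contained sketch of both call patterns, with arbitrary illustrative values:

import torch
from reflectorch.data_generation.reflectivity import reflectivity

q = torch.linspace(0.02, 0.25, 96).expand(2, 96)
thickness = torch.full((2, 1), 80.0)
roughness = torch.full((2, 2), 2.5)
sld = torch.tensor([[4.0, 20.1]]).expand(2, -1)

# one resolution value per curve -> constant (dq or dq/q) smearing
curves_const = reflectivity(q, thickness, roughness, sld,
                            dq=torch.full((2, 1), 0.05), constant_dq=False)

# one resolution value per q point -> pointwise Gauss-Legendre smearing
curves_point = reflectivity(q, thickness, roughness, sld,
                            dq=0.02 * q, gauss_num=17)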
--- /dev/null
+++ b/reflectorch/data_generation/reflectivity/smearing_pointwise.py
@@ -0,0 +1,110 @@
+import torch
+import scipy
+import numpy as np
+from functools import lru_cache
+from typing import Tuple
+
+from reflectorch.data_generation.reflectivity.abeles import abeles
+
+#Pytorch version based on the JAX implementation of pointwise smearing in the refnx package.
+
+@lru_cache(maxsize=128)
+def gauss_legendre(n: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Calculate Gaussian quadrature abscissae and weights.
+
+    Args:
+        n (int): Gaussian quadrature order.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The abscissae and weights for Gauss-Legendre integration.
+    """
+    return scipy.special.p_roots(n)
+
+def gauss(x: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate the Gaussian function.
+
+    Args:
+        x (torch.Tensor): Input tensor.
+
+    Returns:
+        torch.Tensor: Output tensor after applying the Gaussian function.
+    """
+    return torch.exp(-0.5 * x * x)
+
+def abeles_pointwise_smearing(
+    q: torch.Tensor,
+    dq: torch.Tensor,
+    thickness: torch.Tensor,
+    roughness: torch.Tensor,
+    sld: torch.Tensor,
+    gauss_num: int = 17,
+    abeles_func=None,
+    **abeles_kwargs,
+) -> torch.Tensor:
+    """
+    Compute reflectivity with variable smearing using Gaussian quadrature.
+
+    Args:
+        q (torch.Tensor): The momentum transfer (q) values.
+        dq (torch.Tensor): The resolution for curve smearing.
+        thickness (torch.Tensor): The layer thicknesses.
+        roughness (torch.Tensor): The interlayer roughnesses.
+        sld (torch.Tensor): The SLDs of the layers.
+        sld_magnetic (torch.Tensor, optional): The magnetic SLDs of the layers.
+        magnetization_angle (torch.Tensor, optional): The magnetization angles.
+        polarizer_eff (torch.Tensor, optional): The polarizer efficiency.
+        analyzer_eff (torch.Tensor, optional): The analyzer efficiency.
+        abeles_func (Callable, optional): A function implementing the simulation of the reflectivity curves.
+        gauss_num (int, optional): Gaussian quadrature order. Defaults to 17.
+
+    Returns:
+        torch.Tensor: The computed reflectivity curves.
+    """
+    abeles_func = abeles_func or abeles
+
+    if q.shape[0] == 1:
+        q = q.repeat(thickness.shape[0], 1)
+
+    _FWHM = 2 * np.sqrt(2 * np.log(2.0))
+    _INTLIMIT = 3.5
+
+    bs = q.shape[0]
+    nq = q.shape[-1]
+    device = q.device
+
+    quad_order = gauss_num
+    abscissa, weights = gauss_legendre(quad_order)
+    abscissa = torch.tensor(abscissa)[None, :, None].to(device)
+    weights = torch.tensor(weights)[None, :, None].to(device)
+    prefactor = 1.0 / np.sqrt(2 * np.pi)
+
+    gaussvals = prefactor * gauss(abscissa * _INTLIMIT)
+
+    va = q[:, None, :] - _INTLIMIT * dq[:, None, :] / _FWHM
+    vb = q[:, None, :] + _INTLIMIT * dq[:, None, :] / _FWHM
+
+    qvals_for_res_0 = (abscissa * (vb - va) + vb + va) / 2
+    qvals_for_res = qvals_for_res_0.reshape(bs, -1)
+
+    refl_curves = abeles_func(
+        q=qvals_for_res,
+        thickness=thickness,
+        roughness=roughness,
+        sld=sld,
+        **abeles_kwargs
+    )
+
+    # Handle multiple channels
+    if refl_curves.dim() == 3:
+        n_channels = refl_curves.shape[1]
+        refl_curves = refl_curves.reshape(bs, n_channels, quad_order, nq)
+        refl_curves = refl_curves * gaussvals.unsqueeze(1) * weights.unsqueeze(1)
+        refl_curves = torch.sum(refl_curves, dim=2) * _INTLIMIT
+    else:
+        refl_curves = refl_curves.reshape(bs, quad_order, nq)
+        refl_curves = refl_curves * gaussvals * weights
+        refl_curves = torch.sum(refl_curves, dim=1) * _INTLIMIT
+
+    return refl_curves
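The new module approximates the resolution convolution point by point: each smeared value is a Gauss-Legendre weighted sum of the unsmeared curve evaluated at gauss_num nodes spanning q ± 3.5σ, with σ = dq / (2·sqrt(2·ln 2)), i.e. dq is treated as a FWHM. A standalone numpy illustration of that quadrature for a single q point; R here is an arbitrary stand-in function, not a reflectivity model:

import numpy as np
from scipy.special import roots_legendre

q0, dq = 0.10, 0.005                               # nominal q and FWHM resolution at that point
sigma = dq / (2 * np.sqrt(2 * np.log(2.0)))

def R(q):                                          # stand-in for an unsmeared curve
    return np.exp(-200.0 * q)

x, w = roots_legendre(17)                          # quadrature nodes/weights on [-1, 1]
q_nodes = q0 + 3.5 * sigma * x                     # map nodes to [q0 - 3.5*sigma, q0 + 3.5*sigma]
gauss_w = np.exp(-0.5 * (3.5 * x) ** 2) / np.sqrt(2 * np.pi)
smeared = 3.5 * np.sum(w * gauss_w * R(q_nodes))   # same weighting as in abeles_pointwise_smearing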
--- a/reflectorch/data_generation/smearing.py
+++ b/reflectorch/data_generation/smearing.py
@@ -9,15 +9,15 @@ class Smearing(object):
     The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.

     Args:
-        sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (
+        sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (0.01, 0.1).
         constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
            otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to True.
         gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
         share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
     """
     def __init__(self,
-                 sigma_range: tuple = (
-                 constant_dq: bool =
+                 sigma_range: tuple = (0.01, 0.1),
+                 constant_dq: bool = False,
                  gauss_num: int = 31,
                  share_smeared: float = 0.2,
                  ):
@@ -38,31 +38,62 @@ class Smearing(object):
         indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
         indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
         return dq, indices
+
+    def scale_resolutions(self, resolutions: Tensor) -> Tensor:
+        """Scales the q-resolution values to [-1,1] range using the internal sigma range"""
+        sigma_min = 0.0 if self.share_smeared != 1.0 else self.sigma_min
+        return 2 * (resolutions - sigma_min) / (self.sigma_max - sigma_min) - 1
+
+    def get_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs:dict = None):
+        refl_kwargs = refl_kwargs or {}

-    def get_curves(self, q_values: Tensor, params: BasicParams):
         dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)
+        q_resolutions = torch.zeros(q_values.shape[0], 1, dtype=q_values.dtype, device=q_values.device)

         if dq is None:
-            return params.reflectivity(q_values,
-
-
+            return params.reflectivity(q_values, **refl_kwargs), q_resolutions
+
+        refl_kwargs_not_smeared = {}
+        refl_kwargs_smeared = {}
+        for key, value in refl_kwargs.items():
+            if isinstance(value, torch.Tensor) and value.shape[0] == params.batch_size:
+                refl_kwargs_not_smeared[key] = value[~indices]
+                refl_kwargs_smeared[key] = value[indices]
+            else:
+                refl_kwargs_not_smeared[key] = value
+                refl_kwargs_smeared[key] = value

+        # Compute unsmeared reflectivity
         if (~indices).sum().item():
             if q_values.dim() == 2 and q_values.shape[0] > 1:
                 q = q_values[~indices]
             else:
                 q = q_values

-
+            reflectivity_not_smeared = params[~indices].reflectivity(q, **refl_kwargs_not_smeared)
+        else:
+            reflectivity_not_smeared = None

+        # Compute smeared reflectivity
         if indices.sum().item():
             if q_values.dim() == 2 and q_values.shape[0] > 1:
                 q = q_values[indices]
             else:
                 q = q_values

-
-                q, dq=dq, constant_dq=self.constant_dq,
+            reflectivity_smeared = params[indices].reflectivity(
+                q, dq=dq, constant_dq=self.constant_dq, gauss_num=self.gauss_num, **refl_kwargs_smeared
             )
+        else:
+            reflectivity_smeared = None
+
+        curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)
+
+        if (~indices).sum().item():
+            curves[~indices] = reflectivity_not_smeared
+
+        curves[indices] = reflectivity_smeared
+
+        q_resolutions[indices] = dq

-        return curves
+        return curves, q_resolutions
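Two caller-visible changes in the Smearing hunk above: get_curves now returns a (curves, q_resolutions) tuple and forwards extra reflectivity keyword arguments, and scale_resolutions maps the sampled resolutions to [-1, 1]. A rough sketch of the new call pattern; share_smeared=1.0 is chosen so generate_resolutions always yields a dq tensor, and the BasicParams batch needed for get_curves is not constructed here, so that call is only indicated in comments:

import torch
from reflectorch.data_generation.smearing import Smearing

smearing = Smearing(sigma_range=(0.01, 0.1), constant_dq=False, gauss_num=31, share_smeared=1.0)

dq, smeared_mask = smearing.generate_resolutions(16, device='cpu', dtype=torch.float32)
scaled_dq = smearing.scale_resolutions(dq)   # mapped to [-1, 1] using sigma_range

# with a q grid `q_values` and a BasicParams batch `params` (both assumed to exist):
# curves, q_resolutions = smearing.get_curves(q_values, params, refl_kwargs={'r_scale': torch.ones(params.batch_size)})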
--- a/reflectorch/data_generation/utils.py
+++ b/reflectorch/data_generation/utils.py
@@ -57,7 +57,7 @@ def get_reversed_params(thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):
     return reversed_params


-def
+def get_density_profiles_sld(
    thicknesses: Tensor,
    roughnesses: Tensor,
    slds: Tensor,
@@ -120,29 +120,103 @@ def get_erf(z, z0, sigma, amp):
 def get_gauss(z, z0, sigma, amp):
     return amp / (sigma * sqrt(2 * pi)) * torch.exp(- (z - z0) ** 2 / 2 / sigma ** 2)

+def get_density_profiles(
+    thicknesses: torch.Tensor,
+    roughnesses: torch.Tensor,
+    slds: torch.Tensor,
+    ambient_sld: torch.Tensor = None,
+    z_axis: torch.Tensor = None,
+    num: int = 1000,
+    padding_left: float = 0.2,
+    padding_right: float = 1.1,
+):
+    """
+    Args:
+        thicknesses (Tensor): finite layer thicknesses.
+        roughnesses (Tensor): interface roughnesses for all transitions (ambient→layer1 ... layerN→substrate).
+        slds (Tensor): SLDs for the finite layers + substrate.
+        ambient_sld (Tensor, optional): SLD for the top ambient. Defaults to 0.0 if None.
+        z_axis (Tensor, optional): a custom depth axis. If None, a linear axis is generated.
+        num (int): number of points in the generated z-axis (if z_axis is None).
+        padding_left (float): factor to extend the negative (above the surface) portion of z-axis.
+        padding_right (float): factor to extend the positive (into the sample) portion of z-axis.
+
+    Returns:
+        (z_axis, profile, d_profile)
+        z_axis: 1D Tensor of shape (num, ) with the depth coordinates.
+        profile: 2D Tensor of shape (batch_size, num) giving the SLD at each depth.
+        d_profile: 2D Tensor of shape (batch_size, num) giving d(SLD)/dz at each depth.
+    """
+
+    bs, n = thicknesses.shape
+    assert roughnesses.shape == (bs, n + 1), (
+        f"Roughnesses must be (batch_size, num_layers+1). Found {roughnesses.shape} instead."
+    )
+    assert slds.shape == (bs, n + 1), (
+        f"SLDs must be (batch_size, num_layers+1). Found {slds.shape} instead."
+    )
+    assert torch.all(thicknesses >= 0), "Negative thickness encountered."
+    assert torch.all(roughnesses >= 0), "Negative roughness encountered."
+
+    if ambient_sld is None:
+        ambient_sld = torch.zeros((bs, 1), device=thicknesses.device)
+    else:
+        if ambient_sld.ndim == 1:
+            ambient_sld = ambient_sld.unsqueeze(-1)
+
+    slds_all = torch.cat([ambient_sld, slds], dim=-1)  # new dimension: n+2
+    d_rhos = torch.diff(slds_all, dim=-1)  # (bs, n+1)
+
+    interfaces = torch.cat([
+        torch.zeros((bs, 1), device=thicknesses.device),  # z=0 for ambient→layer1
+        thicknesses
+    ], dim=-1).cumsum(dim=-1)  # now shape => (bs, n+1)
+
+    total_thickness = interfaces[..., -1].max()
+    if z_axis is None:
+        z_axis = torch.linspace(
+            -padding_left * total_thickness,
+            padding_right * total_thickness,
+            num,
+            device=thicknesses.device
+        )  # shape => (num,)
+    if z_axis.ndim == 1:
+        z_axis = z_axis.unsqueeze(0)  # shape => (1, num)
+
+    z_b = z_axis.repeat(bs, 1).unsqueeze(1)  # (bs, 1, num)
+    interfaces_b = interfaces.unsqueeze(-1)  # (bs, n+1, 1)
+    sigmas_b = (roughnesses * sqrt(2)).unsqueeze(-1)  # (bs, n+1, 1)
+    d_rhos_b = d_rhos.unsqueeze(-1)  # (bs, n+1, 1)
+
+    profile = get_erf(z_b, interfaces_b, sigmas_b, d_rhos_b).sum(dim=1)  # (bs, num)
+    if ambient_sld is not None:
+        profile = profile + ambient_sld
+
+    d_profile = get_gauss(z_b, interfaces_b, sigmas_b, d_rhos_b).sum(dim=1)  # (bs, num)
+
+    return z_axis.squeeze(0), profile, d_profile

 def get_param_labels(
     num_layers: int, *,
     thickness_name: str = 'Thickness',
     roughness_name: str = 'Roughness',
     sld_name: str = 'SLD',
-    substrate_name: str = 'sub',
-) -> List[str]:
-    thickness_labels = [f'{thickness_name} L{num_layers - i}' for i in range(num_layers)]
-    roughness_labels = [f'{roughness_name} L{num_layers - i}' for i in range(num_layers)] + [f'{roughness_name} {substrate_name}']
-    sld_labels = [f'{sld_name} L{num_layers - i}' for i in range(num_layers)] + [f'{sld_name} {substrate_name}']
-    return thickness_labels + roughness_labels + sld_labels
-
-def get_param_labels_absorption_model(
-    num_layers: int, *,
-    thickness_name: str = 'Thickness',
-    roughness_name: str = 'Roughness',
-    real_sld_name: str = 'SLD real',
     imag_sld_name: str = 'SLD imag',
     substrate_name: str = 'sub',
+    parameterization_type: str = 'standard',
+    number_top_to_bottom: bool = False,
 ) -> List[str]:
-
-
-
-
-
+    def pos(i):
+        return i + 1 if number_top_to_bottom else num_layers - i
+
+    thickness_labels = [f'{thickness_name} L{pos(i)}' for i in range(num_layers)]
+    roughness_labels = [f'{roughness_name} L{pos(i)}' for i in range(num_layers)] + [f'{roughness_name} {substrate_name}']
+    sld_labels = [f'{sld_name} L{pos(i)}' for i in range(num_layers)] + [f'{sld_name} {substrate_name}']
+
+    all_labels = thickness_labels + roughness_labels + sld_labels
+
+    if parameterization_type == 'absorption':
+        imag_sld_labels = [f'{imag_sld_name} L{pos(i)}' for i in range(num_layers)] + [f'{imag_sld_name} {substrate_name}']
+        all_labels = all_labels + imag_sld_labels
+
+    return all_labels
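A short sketch of the two helpers added in the utils hunk, with illustrative shapes (two layers on a substrate; values arbitrary):

import torch
from reflectorch.data_generation.utils import get_density_profiles, get_param_labels

batch_size, n_layers = 4, 2
thicknesses = torch.full((batch_size, n_layers), 120.0)
roughnesses = torch.full((batch_size, n_layers + 1), 4.0)
slds = torch.tensor([[3.5, 1.8, 20.1]]).expand(batch_size, -1)   # layer + substrate SLDs

# batched SLD depth profile and its derivative, built from erf/gaussian steps at each interface
z, profile, d_profile = get_density_profiles(thicknesses, roughnesses, slds, num=500)
# z: (500,), profile and d_profile: (batch_size, 500)

# parameterization_type='absorption' appends imaginary-SLD labels;
# number_top_to_bottom=True labels the i-th parameter L(i+1) instead of L(num_layers - i)
labels = get_param_labels(n_layers, parameterization_type='absorption', number_top_to_bottom=True)
# ['Thickness L1', 'Thickness L2', 'Roughness L1', ..., 'SLD sub', 'SLD imag L1', ..., 'SLD imag sub']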
--- /dev/null
+++ b/reflectorch/extensions/refnx/refnx_conversion.py
@@ -0,0 +1,77 @@
+import numpy as np
+from functools import reduce
+from operator import or_
+
+from reflectorch.inference.inference_model import EasyInferenceModel
+from reflectorch import BasicParams
+
+import refnx
+from refnx.dataset import ReflectDataset, Data1D
+from refnx.analysis import Transform, CurveFitter, Objective, Model, Parameter
+from refnx.reflect import SLD, Slab, ReflectModel
+
+def covert_reflectorch_prediction_to_refnx_structure(inference_model: EasyInferenceModel, pred_params_object: BasicParams, prior_bounds: np.array):
+    assert inference_model.trainer.loader.prior_sampler.param_model.__class__.__name__ == 'StandardModel'
+
+    n_layers = inference_model.trainer.loader.prior_sampler.max_num_layers
+    init_thicknesses = pred_params_object.thicknesses.squeeze().tolist()
+    init_roughnesses = pred_params_object.roughnesses.squeeze().tolist()
+    init_slds = pred_params_object.slds.squeeze().tolist()
+
+    sld_objects = []
+
+    for sld in init_slds:
+        sld_objects.append(SLD(value=sld))
+
+    layer_objects = [SLD(0)()]
+    for i in range(n_layers):
+        layer_objects.append(sld_objects[i](init_thicknesses[i], init_roughnesses[i]))
+
+    layer_objects.append(sld_objects[-1](0, init_roughnesses[-1]))
+
+    thickness_bounds = prior_bounds[:n_layers]
+    roughness_bounds = prior_bounds[n_layers:2*n_layers+1]
+    sld_bounds = prior_bounds[2*n_layers+1:]
+
+    for i, layer in enumerate(layer_objects):
+        if i == 0:
+            print("Ambient (air)")
+            print(80 * '-')
+        elif i < n_layers+1:
+            layer.thick.setp(bounds=thickness_bounds[i-1], vary=True)
+            layer.rough.setp(bounds=roughness_bounds[i-1], vary=True)
+            layer.sld.real.setp(bounds=sld_bounds[i-1], vary=True)
+
+            print(f'Layer {i}')
+            print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
+            print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
+            print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')
+            print(80 * '-')
+        else: #substrate
+            layer.rough.setp(bounds=roughness_bounds[i-1], vary=True)
+            layer.sld.real.setp(bounds=sld_bounds[i-1], vary=True)
+
+            print(f'Substrate')
+            print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
+            print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
+            print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')
+
+    refnx_structure = reduce(or_, layer_objects)
+
+    return refnx_structure
+
+
+###Example usage:
+# refnx_structure = covert_reflectorch_prediction_to_refnx_structure(inference_model, pred_params_object, prior_bounds)
+
+# refnx_reflect_model = ReflectModel(refnx_structure, bkg=1e-10, dq=0.0)
+# refnx_reflect_model.scale.setp(bounds=(0.8, 1.2), vary=True)
+# refnx_reflect_model.q_offset.setp(bounds=(-0.01, 0.01), vary=True)
+# refnx_reflect_model.bkg.setp(bounds=(1e-10, 1e-8), vary=True)
+
+
+# data = Data1D(data=(q_model, exp_curve_interp))
+
+# refnx_objective = Objective(refnx_reflect_model, data, transform=Transform("logY"))
+# fitter = CurveFitter(refnx_objective)
+# fitter.fit('least_squares')