reflectorch 1.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- reflectorch/__init__.py +17 -0
- reflectorch/data_generation/__init__.py +128 -0
- reflectorch/data_generation/dataset.py +216 -0
- reflectorch/data_generation/likelihoods.py +80 -0
- reflectorch/data_generation/noise.py +471 -0
- reflectorch/data_generation/priors/__init__.py +60 -0
- reflectorch/data_generation/priors/base.py +55 -0
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
- reflectorch/data_generation/priors/independent_priors.py +195 -0
- reflectorch/data_generation/priors/multilayer_models.py +311 -0
- reflectorch/data_generation/priors/multilayer_structures.py +104 -0
- reflectorch/data_generation/priors/no_constraints.py +206 -0
- reflectorch/data_generation/priors/parametric_models.py +842 -0
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
- reflectorch/data_generation/priors/params.py +252 -0
- reflectorch/data_generation/priors/sampler_strategies.py +370 -0
- reflectorch/data_generation/priors/scaler_mixin.py +65 -0
- reflectorch/data_generation/priors/subprior_sampler.py +371 -0
- reflectorch/data_generation/priors/utils.py +118 -0
- reflectorch/data_generation/process_data.py +41 -0
- reflectorch/data_generation/q_generator.py +280 -0
- reflectorch/data_generation/reflectivity/__init__.py +102 -0
- reflectorch/data_generation/reflectivity/abeles.py +97 -0
- reflectorch/data_generation/reflectivity/kinematical.py +71 -0
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
- reflectorch/data_generation/reflectivity/smearing.py +138 -0
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/scale_curves.py +112 -0
- reflectorch/data_generation/smearing.py +99 -0
- reflectorch/data_generation/utils.py +223 -0
- reflectorch/extensions/__init__.py +0 -0
- reflectorch/extensions/jupyter/__init__.py +11 -0
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -0
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -0
- reflectorch/extensions/matplotlib/losses.py +32 -0
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/__init__.py +28 -0
- reflectorch/inference/inference_model.py +848 -0
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +55 -0
- reflectorch/inference/multilayer_fitter.py +171 -0
- reflectorch/inference/multilayer_inference_model.py +193 -0
- reflectorch/inference/plotting.py +524 -0
- reflectorch/inference/preprocess_exp/__init__.py +7 -0
- reflectorch/inference/preprocess_exp/attenuation.py +36 -0
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
- reflectorch/inference/preprocess_exp/footprint.py +81 -0
- reflectorch/inference/preprocess_exp/interpolation.py +19 -0
- reflectorch/inference/preprocess_exp/normalize.py +21 -0
- reflectorch/inference/preprocess_exp/preprocess.py +121 -0
- reflectorch/inference/query_matcher.py +82 -0
- reflectorch/inference/record_time.py +43 -0
- reflectorch/inference/sampler_solution.py +56 -0
- reflectorch/inference/scipy_fitter.py +364 -0
- reflectorch/inference/torch_fitter.py +87 -0
- reflectorch/ml/__init__.py +32 -0
- reflectorch/ml/basic_trainer.py +292 -0
- reflectorch/ml/callbacks.py +81 -0
- reflectorch/ml/dataloaders.py +27 -0
- reflectorch/ml/loggers.py +56 -0
- reflectorch/ml/schedulers.py +356 -0
- reflectorch/ml/trainers.py +201 -0
- reflectorch/ml/utils.py +2 -0
- reflectorch/models/__init__.py +16 -0
- reflectorch/models/activations.py +50 -0
- reflectorch/models/encoders/__init__.py +19 -0
- reflectorch/models/encoders/conv_encoder.py +219 -0
- reflectorch/models/encoders/conv_res_net.py +115 -0
- reflectorch/models/encoders/fno.py +134 -0
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -0
- reflectorch/models/networks/mlp_networks.py +434 -0
- reflectorch/models/networks/residual_net.py +157 -0
- reflectorch/paths.py +29 -0
- reflectorch/runs/__init__.py +31 -0
- reflectorch/runs/config.py +25 -0
- reflectorch/runs/slurm_utils.py +93 -0
- reflectorch/runs/train.py +78 -0
- reflectorch/runs/utils.py +405 -0
- reflectorch/test_config.py +4 -0
- reflectorch/train.py +4 -0
- reflectorch/train_on_cluster.py +4 -0
- reflectorch/utils.py +98 -0
- reflectorch-1.5.1.dist-info/METADATA +151 -0
- reflectorch-1.5.1.dist-info/RECORD +96 -0
- reflectorch-1.5.1.dist-info/WHEEL +5 -0
- reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
- reflectorch-1.5.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def abeles_np(
        q: np.ndarray,
        thickness: np.ndarray,
        roughness: np.ndarray,
        sld: np.ndarray,
):
    """Compute specular reflectivity with the Abeles transfer-matrix method (NumPy version).

    An ambient medium (SLD 0, thickness 0) is prepended internally. ``sld`` holds
    the layer SLDs followed by the substrate SLD and is multiplied by 1e-6 here,
    i.e. it is given in units of 1e-6. Interfacial roughness enters through the
    factor ``exp(-2 * k_n * k_{n+1} * sigma**2)`` applied to each Fresnel coefficient.

    Args:
        q: momentum transfer values, shape ``(num_q,)`` or ``(batch, num_q)``.
        thickness: layer thicknesses, shape ``(num_layers,)`` or ``(batch, num_layers)``.
        roughness: interface roughnesses, shape ``(num_layers + 1,)`` or ``(batch, num_layers + 1)``.
        sld: layer + substrate SLDs, shape ``(num_layers + 1,)`` or ``(batch, num_layers + 1)``.

    Returns:
        np.ndarray: reflectivity clipped to at most 1; shape ``(num_q,)`` when all
        inputs are 1-D, otherwise ``(batch, num_q)``.
    """
    # Compare dtypes with ``==``: ``q.dtype`` is a ``np.dtype`` instance, so an
    # identity check against the scalar type ``np.float64`` is never True and
    # would force single-precision complex arithmetic for float64 input.
    c_dtype = np.complex128 if q.dtype == np.float64 else np.complex64

    zero_batch = q.ndim == thickness.ndim == roughness.ndim == sld.ndim == 1

    thickness = np.atleast_2d(thickness)
    roughness = np.atleast_2d(roughness)
    sld = np.atleast_2d(sld)

    batch_size, num_layers = thickness.shape

    # Prepend the ambient medium (zero SLD, zero thickness) to every sample.
    sld = np.concatenate([np.zeros((batch_size, 1)).astype(sld.dtype), sld], -1)[:, None]
    thickness = np.concatenate([np.zeros((batch_size, 1)).astype(thickness.dtype), thickness], -1)[:, None]
    # Only the squared roughness enters the damping factor below.
    roughness = roughness[:, None] ** 2

    # The tiny imaginary part presumably keeps the complex square root away
    # from its branch cut for non-absorbing media.
    sld = sld * 1e-6 + 1e-30j

    k_z0 = (q / 2).astype(c_dtype)

    if k_z0.ndim == 1:
        k_z0 = k_z0[None]

    if k_z0.ndim == 2:
        k_z0 = k_z0[..., None]

    # k_n.shape - (batch, q, layers)
    k_n = np.sqrt(k_z0 ** 2 - 4 * np.pi * sld)

    k_n, k_np1 = k_n[..., :-1], k_n[..., 1:]

    beta = 1j * thickness * k_n

    exp_beta = np.exp(beta)
    exp_m_beta = np.exp(-beta)

    rn = (k_n - k_np1) / (k_n + k_np1) * np.exp(- 2 * k_n * k_np1 * roughness)

    c_matrices = np.stack([
        np.stack([exp_beta, rn * exp_m_beta], -1),
        np.stack([rn * exp_beta, exp_m_beta], -1),
    ], -1)

    # Move the interface axis to the front so the 2x2 matrices can be chained.
    c_matrices = np.moveaxis(c_matrices, -3, 0)

    m, c_matrices = c_matrices[0], c_matrices[1:]

    for c in c_matrices:
        m = m @ c

    r = np.abs(m[..., 1, 0] / m[..., 0, 0]) ** 2
    r = np.clip(r, None, 1.)

    if zero_batch:
        r = r[0]

    return r
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def kinematical_approximation_np(
        q: np.ndarray,
        thickness: np.ndarray,
        roughness: np.ndarray,
        sld: np.ndarray,
):
    """Reflectivity in the kinematical approximation (NumPy version).

    The scattering amplitude is the sum over interfaces of the SLD contrast,
    damped by ``exp(-(sigma * q)**2 / 2)`` and phase-shifted by the interface
    depth. The squared amplitude is scaled by the substrate Fresnel reflectivity
    divided by the squared real substrate SLD, then clipped to at most 1.

    Args:
        q: momentum transfer values, ``(num_q,)`` or ``(batch, num_q)``.
        thickness: layer thicknesses, ``(num_layers,)`` or ``(batch, num_layers)``.
        roughness: interface roughnesses, one per SLD step.
        sld: layer + substrate SLDs, multiplied by 1e-6 internally.

    Returns:
        np.ndarray: ``(num_q,)`` for fully 1-D inputs, otherwise ``(batch, num_q)``.
    """
    unbatched = q.ndim == thickness.ndim == roughness.ndim == sld.ndim == 1

    thick_2d = np.atleast_2d(thickness)
    sigma_2d = np.atleast_2d(roughness)
    rho = np.atleast_2d(sld) * 1e-6 + 1e-30j
    rho_sub = rho[:, -1:]

    n_batch, _ = thick_2d.shape

    # Bring q to shape (batch or 1, num_q, 1) so it broadcasts over interfaces.
    if q.ndim == 1:
        q_col = q[None, :, None]
    elif q.ndim == 2:
        q_col = q[..., None]
    else:
        q_col = q

    # SLD contrast at each interface: first SLD itself, then successive differences.
    contrast = np.concatenate([rho[..., :1], np.diff(rho, axis=-1)], -1)[:, None]
    # Depth of every interface measured from the surface.
    depth = np.cumsum(np.concatenate([np.zeros((n_batch, 1)), thick_2d], -1), -1)[:, None]
    sigma = sigma_2d[:, None]

    phase = - (sigma * q_col) ** 2 / 2 + 1j * (q_col * depth)
    amplitude = (contrast * np.exp(phase)).sum(-1)
    intensity = np.abs(amplitude).astype(float) ** 2

    fresnel = _get_resnel_reflectivity_np(q_col, rho_sub[:, None])
    result = np.clip(intensity * fresnel / np.real(rho_sub) ** 2, None, 1.)

    return result[0] if unbatched else result
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def _get_resnel_reflectivity_np(q, substrate_slds):
|
|
114
|
+
_RE_CONST = 0.28174103675406496
|
|
115
|
+
|
|
116
|
+
q_c = np.sqrt(substrate_slds + 0j) / _RE_CONST * 2
|
|
117
|
+
q_prime = np.sqrt(q ** 2 - q_c ** 2 + 0j)
|
|
118
|
+
r_f = np.abs((q - q_prime) / (q + q_prime)).astype(float) ** 2
|
|
119
|
+
|
|
120
|
+
return r_f[..., 0]
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from math import pi, sqrt, log
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.reflectivity.abeles import abeles
|
|
8
|
+
from torch.nn.functional import conv1d, pad
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def abeles_constant_smearing(
        q: Tensor,
        thickness: Tensor,
        roughness: Tensor,
        sld: Tensor,
        dq: Tensor = None,
        gauss_num: int = 31,
        constant_dq: bool = False,
        abeles_func=None,
        **abeles_kwargs
):
    """Compute resolution-smeared reflectivity curves by Gaussian convolution.

    The reflectivity is simulated on a dense auxiliary q grid, convolved row by
    row with a Gaussian kernel built from ``dq`` and linearly interpolated back
    onto ``q``.

    Args:
        q: momentum transfer values, ``(batch, num_q)``; a single row ``(1, num_q)``
            is repeated for the whole batch.
        thickness, roughness, sld: layer parameters forwarded to ``abeles_func``.
        dq: one resolution per curve, expected shape ``(batch, 1)``. Absolute
            resolution if ``constant_dq`` is True, otherwise relative (dq/q,
            the auxiliary grid is then log-spaced). Required despite the default.
        gauss_num: number of samples of the Gaussian kernel.
        constant_dq: selects absolute (True) or relative (False) resolution.
        abeles_func: reflectivity simulator; defaults to ``abeles``.
        **abeles_kwargs: forwarded to ``abeles_func``.

    Returns:
        Tensor: smeared curves evaluated at ``q``.
    """
    abeles_func = abeles_func or abeles

    # Bring q and dq to the dtype/device of the layer parameters. ``Tensor.to``
    # returns the tensor itself when nothing needs converting, so each tensor is
    # effectively checked on its own dtype and device.
    q = q.to(thickness)
    dq = dq.to(thickness)

    if q.shape[0] == 1:
        q = q.repeat(thickness.shape[0], 1)

    q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
    kernels = _get_t_gauss_kernels(dq, gauss_num)

    curves = abeles_func(q_lin, thickness, roughness, sld, **abeles_kwargs)

    # Reflect-pad by half the kernel length so that, for odd kernel lengths,
    # the convolution output has the same length as ``curves``.
    padding = (kernels.shape[-1] - 1) // 2
    padded_curves = pad(curves, (padding, padding), 'reflect')

    # groups == number of kernels: every row is convolved with its own Gaussian.
    smeared_curves = conv1d(
        padded_curves, kernels[:, None], groups=kernels.shape[0],
    )

    # If the simulator returned more rows than there are q rows (presumably
    # several channels per sample - TODO confirm), tile the grids to match.
    if q.shape[0] != smeared_curves.shape[0]:
        repeat_factor = smeared_curves.shape[0] // q.shape[0]
        q = q.repeat(repeat_factor, 1)
        q_lin = q_lin.repeat(repeat_factor, 1)

    smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)

    return smeared_curves
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
_FWHM = 2 * sqrt(2 * log(2.0))
|
|
56
|
+
_2PI_SQRT = 1. / sqrt(2 * pi)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _batch_linspace(start: Tensor, end: Tensor, num: int):
|
|
60
|
+
return torch.linspace(0, 1, int(num), device=end.device, dtype=end.dtype)[None] * (end - start) + start
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _torch_gauss(x, s):
    """Normal probability density with mean 0 and standard deviation ``s`` at ``x``."""
    norm = _2PI_SQRT / s
    exponent = -0.5 * x ** 2 / s / s
    return norm * torch.exp(exponent)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
    """Build one sampled Gaussian kernel per resolution.

    Each kernel spans +/- 1.7 * resolution with ``gaussnum`` samples. The
    resolution is treated as a FWHM (sigma = resolution / _FWHM). Values are
    multiplied by the sample spacing so that summing a kernel approximates
    integrating the density.
    """
    half_width = 1.7 * resolutions
    grid = _batch_linspace(-half_width, half_width, gaussnum)
    spacing = (grid[:, 1] - grid[:, 0])[:, None]
    sigma = resolutions / _FWHM
    return _torch_gauss(grid, sigma) * spacing
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = False):
    """Return the auxiliary q grid matching the resolution type.

    Absolute (constant dq) resolutions get a linearly spaced grid, relative
    resolutions a log-spaced one.
    """
    builder = _get_q_axes_for_constant_dq if constant_dq else _get_q_axes_for_linear_dq
    return builder(q, resolutions, gaussnum)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51):
    """Log-spaced auxiliary q grid for relative (dq/q) resolutions.

    The grid extends 6 * resolution / _FWHM (in log10 units) beyond the q range
    of each row. Its log10 step equals the kernel step 1.7 * resolution / _FWHM
    divided by half the kernel length. Rows with fewer points are padded at the
    start by ``_batch_linspace_with_padding``.
    """
    half_points = (gaussnum - 1) / 2

    low_q = q.min(dim=1).values.clamp_min(1e-6)
    high_q = q.max(dim=1).values

    spread = 6 * resolutions / _FWHM
    log_start = torch.log10(low_q)[:, None] - spread
    log_end = torch.log10(high_q[:, None] * (1 + spread))

    log_step = 1.7 * resolutions / _FWHM / half_points
    num_points = ((log_end - log_start).abs() / log_step).abs().round().to(int)

    return 10 ** _batch_linspace_with_padding(log_start, log_end, num_points)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _get_q_axes_for_constant_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51) -> Tensor:
    """Linearly spaced auxiliary q grid for absolute (constant dq) resolutions.

    Each row covers its q range extended by 1.7 * resolution on both sides.
    The step equals the Gaussian kernel step, and values are clamped to at least 1e-6.
    """
    half_points = (gaussnum - 1) / 2

    margin = resolutions * 1.7
    q_start = q.min(dim=1).values[:, None] - margin
    q_end = q.max(dim=1).values[:, None] + margin

    step = 1.7 * resolutions / half_points
    num_points = ((q_end - q_start).abs() / step).abs().round().to(int)

    grid = _batch_linspace_with_padding(q_start, q_end, num_points)
    return grid.clamp_min(1e-6)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _batch_linspace_with_padding(start: Tensor, end: Tensor, nums: Tensor) -> Tensor:
    """Batched linspace with a different number of points per row.

    Every row gets ``max(nums)`` columns. Row ``i`` ends with ``nums[i]`` evenly
    spaced points from ``start[i]`` to ``end[i]`` and is left-padded with copies
    of ``start[i]`` up to the common length.

    Args:
        start: per-row start values, ``(batch, 1)`` as produced by the q-axis helpers.
        end: per-row end values, same shape as ``start``.
        nums: integer point counts per row, same shape as ``start``.

    Returns:
        Tensor: ``(batch, max(nums))`` grid in the dtype of ``start``/``end``.
    """
    # Rows need at least two points; a single point would divide by zero below
    # and turn the whole row into NaNs.
    nums = nums.clamp_min(2)
    max_num = nums.max().int().item()

    # Compute in the dtype of ``start``: true division of an integer tensor
    # yields the default float dtype (normally float32), which would cap the
    # grid precision even for float64 inputs.
    deltas = 1 / (nums.to(start.dtype) - 1)

    # Shorter rows start at a negative offset on the unit interval; clamping to 0
    # turns their leading samples into padding equal to ``start``.
    x = torch.clamp_min_(_batch_linspace(deltas * (nums - max_num), torch.ones_like(deltas), max_num), 0)

    return x * (end - start) + start
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _batch_linear_interp1d(x: Tensor, y: Tensor, x_new: Tensor) -> Tensor:
|
|
127
|
+
eps = torch.finfo(y.dtype).eps
|
|
128
|
+
|
|
129
|
+
ind = torch.searchsorted(x.contiguous(), x_new.contiguous())
|
|
130
|
+
|
|
131
|
+
ind = torch.clamp_(ind - 1, 0, x.shape[-1] - 2)
|
|
132
|
+
slopes = (y[..., 1:] - y[..., :-1]) / (eps + (x[..., 1:] - x[..., :-1]))
|
|
133
|
+
ind_y = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * y.shape[1]
|
|
134
|
+
ind_slopes = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * slopes.shape[1]
|
|
135
|
+
|
|
136
|
+
y_new = y.flatten()[ind_y] + slopes.flatten()[ind_slopes] * (x_new - x.flatten()[ind_y])
|
|
137
|
+
|
|
138
|
+
return y_new
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import scipy
|
|
3
|
+
import numpy as np
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from typing import Tuple
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.reflectivity.abeles import abeles
|
|
8
|
+
|
|
9
|
+
# PyTorch version based on the JAX implementation of pointwise smearing in the refnx package.
|
|
10
|
+
|
|
11
|
+
@lru_cache(maxsize=128)
def gauss_legendre(n: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculate Gauss-Legendre quadrature abscissae and weights on [-1, 1].

    Results are cached per ``n``. The same array objects are returned on every
    call, so callers must not modify them in place.

    Args:
        n (int): Gaussian quadrature order (number of nodes).

    Returns:
        Tuple[np.ndarray, np.ndarray]: The abscissae and weights (NumPy arrays)
        for Gauss-Legendre integration.
    """
    # Import the submodule explicitly instead of relying on ``import scipy``
    # exposing ``scipy.special``; ``roots_legendre`` is the documented name of
    # the routine previously reached via ``p_roots``.
    from scipy.special import roots_legendre

    return roots_legendre(n)
|
|
23
|
+
|
|
24
|
+
def gauss(x: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the unnormalised standard Gaussian elementwise.

    Args:
        x (torch.Tensor): Input tensor.

    Returns:
        torch.Tensor: ``exp(-x**2 / 2)`` with the same shape as ``x``.
    """
    exponent = -0.5 * x * x
    return exponent.exp()
|
|
35
|
+
|
|
36
|
+
def abeles_pointwise_smearing(
        q: torch.Tensor,
        dq: torch.Tensor,
        thickness: torch.Tensor,
        roughness: torch.Tensor,
        sld: torch.Tensor,
        gauss_num: int = 17,
        abeles_func=None,
        **abeles_kwargs,
) -> torch.Tensor:
    """
    Compute reflectivity with pointwise (per-q) smearing using Gaussian quadrature.

    At every q point the reflectivity is integrated against a Gaussian whose
    FWHM is ``dq`` (sigma = dq / FWHM factor). The integral is taken over
    +/- 3.5 sigma with ``gauss_num`` Gauss-Legendre nodes.

    Args:
        q (torch.Tensor): The momentum transfer (q) values, ``(batch, num_q)``;
            a single row ``(1, num_q)`` is repeated for the whole batch.
        dq (torch.Tensor): The resolution (FWHM) for curve smearing, 2-D and
            broadcastable against ``q``.
        thickness (torch.Tensor): The layer thicknesses.
        roughness (torch.Tensor): The interlayer roughnesses.
        sld (torch.Tensor): The SLDs of the layers.
        gauss_num (int, optional): Gaussian quadrature order. Defaults to 17.
        abeles_func (Callable, optional): A function implementing the simulation
            of the reflectivity curves. Defaults to ``abeles``.
        **abeles_kwargs: Extra keyword arguments forwarded to ``abeles_func``
            (e.g. magnetic SLDs, magnetization angles, polarizer/analyzer
            efficiencies, if the chosen ``abeles_func`` accepts them).

    Returns:
        torch.Tensor: The smeared curves, ``(batch, num_q)``, or
        ``(batch, n_channels, num_q)`` when ``abeles_func`` returns several channels.
    """
    abeles_func = abeles_func or abeles

    if q.shape[0] == 1:
        q = q.repeat(thickness.shape[0], 1)

    _FWHM = 2 * np.sqrt(2 * np.log(2.0))
    _INTLIMIT = 3.5

    bs = q.shape[0]
    nq = q.shape[-1]
    device = q.device

    quad_order = gauss_num
    abscissa, weights = gauss_legendre(quad_order)
    # Build the quadrature tensors directly in q's dtype. Converting the float64
    # NumPy arrays as-is would promote float32 q (and the returned curves) to
    # float64. ``torch.tensor`` copies, so the cached arrays are never shared.
    abscissa = torch.tensor(abscissa, dtype=q.dtype, device=device)[None, :, None]
    weights = torch.tensor(weights, dtype=q.dtype, device=device)[None, :, None]
    prefactor = 1.0 / np.sqrt(2 * np.pi)

    # Normalised Gaussian density at the nodes mapped onto [-_INTLIMIT, _INTLIMIT].
    gaussvals = prefactor * gauss(abscissa * _INTLIMIT)

    # Integration interval [va, vb] around every q point.
    va = q[:, None, :] - _INTLIMIT * dq[:, None, :] / _FWHM
    vb = q[:, None, :] + _INTLIMIT * dq[:, None, :] / _FWHM

    # Map the nodes from [-1, 1] onto [va, vb]: shape (batch, quad_order, num_q).
    qvals_for_res_0 = (abscissa * (vb - va) + vb + va) / 2
    qvals_for_res = qvals_for_res_0.reshape(bs, -1)

    refl_curves = abeles_func(
        q=qvals_for_res,
        thickness=thickness,
        roughness=roughness,
        sld=sld,
        **abeles_kwargs
    )

    # Handle multiple channels
    if refl_curves.dim() == 3:
        n_channels = refl_curves.shape[1]
        refl_curves = refl_curves.reshape(bs, n_channels, quad_order, nq)
        refl_curves = refl_curves * gaussvals.unsqueeze(1) * weights.unsqueeze(1)
        refl_curves = torch.sum(refl_curves, dim=2) * _INTLIMIT
    else:
        refl_curves = refl_curves.reshape(bs, quad_order, nq)
        refl_curves = refl_curves * gaussvals * weights
        refl_curves = torch.sum(refl_curves, dim=1) * _INTLIMIT

    return refl_curves
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from reflectorch.data_generation.priors import PriorSampler
|
|
7
|
+
from reflectorch.paths import SAVED_MODELS_DIR
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CurvesScaler:
    """Base class for curve scalers.

    Subclasses implement ``scale`` (physical curves -> network-friendly values)
    and ``restore`` (the inverse mapping).
    """

    def scale(self, curves: Tensor):
        """Map physical reflectivity curves to the scaled representation."""
        raise NotImplementedError()

    def restore(self, curves: Tensor):
        """Map scaled curves back to physical reflectivity values."""
        raise NotImplementedError()
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class LogAffineCurvesScaler(CurvesScaler):
    r""" Curve scaler which scales the reflectivity curves according to the logarithmic affine transformation:
    :math:`\log_{10}(R + eps) \cdot weight + bias`.

    Args:
        weight (float): multiplication factor in the transformation
        bias (float): addition term in the transformation
        eps (float): sets the minimum intensity value of the reflectivity curves which is considered
    """
    # The docstring is a raw string: ``\l`` and ``\c`` are invalid escape
    # sequences in a normal string literal (SyntaxWarning on Python 3.12+).

    def __init__(self, weight: float = 0.1, bias: float = 0.5, eps: float = 1e-10):
        self.weight = weight
        self.bias = bias
        self.eps = eps

    def scale(self, curves: Tensor):
        """scales the reflectivity curves to a ML-friendly range

        Args:
            curves (Tensor): original reflectivity curves

        Returns:
            Tensor: reflectivity curves scaled to a ML-friendly range
        """
        return torch.log10(curves + self.eps) * self.weight + self.bias

    def restore(self, curves: Tensor):
        """restores the physical reflectivity curves (inverse of ``scale``)

        Args:
            curves (Tensor): scaled reflectivity curves

        Returns:
            Tensor: reflectivity curves restored to the physical range
        """
        return 10 ** ((curves - self.bias) / self.weight) - self.eps
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class MeanNormalizationCurvesScaler(CurvesScaler):
    """Curve scaler which scales the reflectivity curves by the precomputed mean of a batch of curves

    Args:
        path (str, optional): path to the precomputed mean of the curves, only used if ``curves_mean`` is None. Defaults to None.
        curves_mean (Tensor, optional): the precomputed mean of the curves. Defaults to None.
        device (torch.device, optional): the Pytorch device. Defaults to 'cuda'.

    Raises:
        ValueError: if neither ``path`` nor ``curves_mean`` is given.
    """

    def __init__(self, path: str = None, curves_mean: Tensor = None, device: torch.device = 'cuda'):
        if curves_mean is None:
            if path is None:
                raise ValueError('Either `path` or `curves_mean` has to be provided.')
            # Load onto the CPU first; the tensor is moved to ``device`` below.
            curves_mean = torch.load(self.get_path(path), map_location='cpu')
        self.curves_mean = curves_mean.to(device)

    def scale(self, curves: Tensor):
        """scales the reflectivity curves to a ML-friendly range

        Args:
            curves (Tensor): original reflectivity curves

        Returns:
            Tensor: reflectivity curves scaled to a ML-friendly range
        """
        # Cache the mean on the dtype/device of the incoming curves.
        self.curves_mean = self.curves_mean.to(curves)
        return curves / self.curves_mean - 1

    def restore(self, curves: Tensor):
        """restores the physical reflectivity curves

        Args:
            curves (Tensor): scaled reflectivity curves

        Returns:
            Tensor: reflectivity curves restored to the physical range
        """
        self.curves_mean = self.curves_mean.to(curves)
        return (curves + 1) * self.curves_mean

    @staticmethod
    def save(prior_sampler: PriorSampler, q: Tensor, path: str, num: int = 16384):
        """computes the mean of a batch of reflectivity curves and saves it

        Args:
            prior_sampler (PriorSampler): the prior sampler
            q (Tensor): the q values
            path (str): the path for saving the mean of the curves
            num (int, optional): the number of curves used to compute the mean. Defaults to 16384.
        """
        params = prior_sampler.sample(num)
        curves_mean = params.reflectivity(q, log=False).mean(0).cpu()
        torch.save(curves_mean, MeanNormalizationCurvesScaler.get_path(path))

    @staticmethod
    def get_path(path: str) -> Path:
        """Resolve ``path`` under ``SAVED_MODELS_DIR``, appending ``.pt`` if missing.

        Accepts ``str`` or any ``os.PathLike``; other types raise ``TypeError``.
        """
        import os

        path = os.fspath(path)
        if not path.endswith('.pt'):
            path = path + '.pt'
        return SAVED_MODELS_DIR / path
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from torch import Tensor
|
|
3
|
+
|
|
4
|
+
from reflectorch.data_generation.priors.parametric_subpriors import BasicParams
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Smearing(object):
    """Class which applies resolution smearing to the reflectivity curves.
    The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.

    Args:
        sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (0.01, 0.1).
        constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
            otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to False.
        gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
        share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
    """
    def __init__(self,
                 sigma_range: tuple = (0.01, 0.1),
                 constant_dq: bool = False,
                 gauss_num: int = 31,
                 share_smeared: float = 0.2,
                 ):
        self.sigma_min, self.sigma_max = sigma_range
        self.sigma_delta = self.sigma_max - self.sigma_min
        self.constant_dq = constant_dq
        self.gauss_num = gauss_num
        self.share_smeared = share_smeared

    def __repr__(self):
        # Format: Smearing((sigma_min, sigma_max))
        return f'Smearing(({self.sigma_min}, {self.sigma_max}))'

    def generate_resolutions(self, batch_size: int, device=None, dtype=None):
        """Sample resolutions for a random subset of the batch.

        Returns:
            tuple: ``(dq, indices)``, where ``dq`` has shape ``(num_smeared, 1)``
            and is drawn uniformly from ``[sigma_min, sigma_max)``, and
            ``indices`` is a boolean mask of length ``batch_size`` marking the
            smeared curves. Returns ``(None, None)`` when
            ``int(batch_size * share_smeared)`` is 0.
        """
        num_smeared = int(batch_size * self.share_smeared)
        if not num_smeared:
            return None, None
        dq = torch.rand(num_smeared, 1, device=device, dtype=dtype) * self.sigma_delta + self.sigma_min
        indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
        indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
        return dq, indices

    def scale_resolutions(self, resolutions: Tensor) -> Tensor:
        """Scales the q-resolution values to [-1,1] range using the internal sigma range"""
        # Unsmeared curves carry resolution 0, so the lower bound is 0 unless every curve is smeared.
        sigma_min = 0.0 if self.share_smeared != 1.0 else self.sigma_min
        return 2 * (resolutions - sigma_min) / (self.sigma_max - sigma_min) - 1

    def get_curves(self, q_values: Tensor, params: 'BasicParams', refl_kwargs: dict = None):
        """Simulate a batch of curves, smearing a random share of them.

        Args:
            q_values (Tensor): per-curve q grids ``(batch, num_q)`` or a single
                grid shared by the batch (``(num_q,)`` or ``(1, num_q)``).
            params (BasicParams): the sampled layer parameters.
            refl_kwargs (dict, optional): extra arguments for ``params.reflectivity``;
                tensors whose first dimension equals the batch size are split
                between smeared and unsmeared curves.

        Returns:
            tuple: ``(curves, q_resolutions)`` with curves of shape
            ``(batch, num_q)`` and resolutions of shape ``(batch, 1)``
            (0 for unsmeared curves).
        """
        refl_kwargs = refl_kwargs or {}

        dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)
        # One resolution per curve. Sizing by ``q_values.shape[0]`` would be
        # wrong whenever a single q grid is shared by the whole batch.
        q_resolutions = torch.zeros(params.batch_size, 1, dtype=q_values.dtype, device=q_values.device)

        if dq is None:
            return params.reflectivity(q_values, **refl_kwargs), q_resolutions

        refl_kwargs_not_smeared = {}
        refl_kwargs_smeared = {}
        for key, value in refl_kwargs.items():
            if isinstance(value, torch.Tensor) and value.shape[0] == params.batch_size:
                refl_kwargs_not_smeared[key] = value[~indices]
                refl_kwargs_smeared[key] = value[indices]
            else:
                refl_kwargs_not_smeared[key] = value
                refl_kwargs_smeared[key] = value

        per_curve_q = q_values.dim() == 2 and q_values.shape[0] > 1
        has_unsmeared = bool((~indices).any())
        has_smeared = bool(indices.any())

        # Compute unsmeared reflectivity
        if has_unsmeared:
            q = q_values[~indices] if per_curve_q else q_values
            reflectivity_not_smeared = params[~indices].reflectivity(q, **refl_kwargs_not_smeared)
        else:
            reflectivity_not_smeared = None

        # Compute smeared reflectivity
        if has_smeared:
            q = q_values[indices] if per_curve_q else q_values
            reflectivity_smeared = params[indices].reflectivity(
                q, dq=dq, constant_dq=self.constant_dq, gauss_num=self.gauss_num, **refl_kwargs_smeared
            )
        else:
            reflectivity_smeared = None

        curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)

        if has_unsmeared:
            curves[~indices] = reflectivity_not_smeared

        if has_smeared:
            curves[indices] = reflectivity_smeared

        q_resolutions[indices] = dq

        return curves, q_resolutions
|