reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/__init__.py +17 -17
- reflectorch/data_generation/__init__.py +128 -126
- reflectorch/data_generation/dataset.py +210 -210
- reflectorch/data_generation/likelihoods.py +80 -80
- reflectorch/data_generation/noise.py +470 -470
- reflectorch/data_generation/priors/__init__.py +60 -60
- reflectorch/data_generation/priors/base.py +55 -55
- reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
- reflectorch/data_generation/priors/independent_priors.py +195 -195
- reflectorch/data_generation/priors/multilayer_models.py +311 -311
- reflectorch/data_generation/priors/multilayer_structures.py +104 -104
- reflectorch/data_generation/priors/no_constraints.py +206 -206
- reflectorch/data_generation/priors/parametric_models.py +841 -841
- reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
- reflectorch/data_generation/priors/params.py +252 -252
- reflectorch/data_generation/priors/sampler_strategies.py +369 -369
- reflectorch/data_generation/priors/scaler_mixin.py +65 -65
- reflectorch/data_generation/priors/subprior_sampler.py +371 -371
- reflectorch/data_generation/priors/utils.py +118 -118
- reflectorch/data_generation/process_data.py +41 -41
- reflectorch/data_generation/q_generator.py +280 -246
- reflectorch/data_generation/reflectivity/__init__.py +102 -102
- reflectorch/data_generation/reflectivity/abeles.py +97 -97
- reflectorch/data_generation/reflectivity/kinematical.py +70 -70
- reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
- reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
- reflectorch/data_generation/reflectivity/smearing.py +138 -138
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
- reflectorch/data_generation/scale_curves.py +112 -112
- reflectorch/data_generation/smearing.py +98 -98
- reflectorch/data_generation/utils.py +223 -222
- reflectorch/extensions/jupyter/__init__.py +11 -6
- reflectorch/extensions/jupyter/api.py +85 -0
- reflectorch/extensions/jupyter/callbacks.py +34 -34
- reflectorch/extensions/jupyter/components.py +758 -0
- reflectorch/extensions/jupyter/custom_select.py +268 -0
- reflectorch/extensions/jupyter/log_widget.py +241 -0
- reflectorch/extensions/jupyter/model_selection.py +495 -0
- reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
- reflectorch/extensions/jupyter/widget.py +625 -0
- reflectorch/extensions/matplotlib/__init__.py +5 -5
- reflectorch/extensions/matplotlib/losses.py +32 -32
- reflectorch/extensions/refnx/refnx_conversion.py +76 -76
- reflectorch/inference/__init__.py +28 -24
- reflectorch/inference/inference_model.py +847 -851
- reflectorch/inference/input_interface.py +239 -0
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/multilayer_fitter.py +171 -171
- reflectorch/inference/multilayer_inference_model.py +193 -193
- reflectorch/inference/plotting.py +524 -98
- reflectorch/inference/preprocess_exp/__init__.py +6 -6
- reflectorch/inference/preprocess_exp/attenuation.py +36 -36
- reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
- reflectorch/inference/preprocess_exp/footprint.py +81 -81
- reflectorch/inference/preprocess_exp/interpolation.py +19 -16
- reflectorch/inference/preprocess_exp/normalize.py +21 -21
- reflectorch/inference/preprocess_exp/preprocess.py +121 -121
- reflectorch/inference/query_matcher.py +81 -81
- reflectorch/inference/record_time.py +43 -43
- reflectorch/inference/sampler_solution.py +56 -56
- reflectorch/inference/scipy_fitter.py +272 -248
- reflectorch/inference/torch_fitter.py +87 -87
- reflectorch/ml/__init__.py +32 -32
- reflectorch/ml/basic_trainer.py +292 -292
- reflectorch/ml/callbacks.py +80 -80
- reflectorch/ml/dataloaders.py +26 -26
- reflectorch/ml/loggers.py +55 -55
- reflectorch/ml/schedulers.py +355 -355
- reflectorch/ml/trainers.py +200 -191
- reflectorch/ml/utils.py +2 -2
- reflectorch/models/__init__.py +15 -14
- reflectorch/models/activations.py +50 -50
- reflectorch/models/encoders/__init__.py +19 -17
- reflectorch/models/encoders/conv_encoder.py +218 -218
- reflectorch/models/encoders/conv_res_net.py +115 -115
- reflectorch/models/encoders/fno.py +133 -133
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +14 -14
- reflectorch/models/networks/mlp_networks.py +434 -428
- reflectorch/models/networks/residual_net.py +156 -156
- reflectorch/paths.py +29 -27
- reflectorch/runs/__init__.py +31 -31
- reflectorch/runs/config.py +25 -25
- reflectorch/runs/slurm_utils.py +93 -93
- reflectorch/runs/train.py +78 -78
- reflectorch/runs/utils.py +404 -401
- reflectorch/test_config.py +4 -4
- reflectorch/train.py +4 -4
- reflectorch/train_on_cluster.py +4 -4
- reflectorch/utils.py +98 -68
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
- reflectorch-1.5.0.dist-info/RECORD +96 -0
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
- reflectorch-1.3.0.dist-info/RECORD +0 -86
- {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,138 +1,138 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
|
-
from math import pi, sqrt, log
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
from torch import Tensor
|
|
6
|
-
|
|
7
|
-
from reflectorch.data_generation.reflectivity.abeles import abeles
|
|
8
|
-
from torch.nn.functional import conv1d, pad
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def abeles_constant_smearing(
    q: Tensor,
    thickness: Tensor,
    roughness: Tensor,
    sld: Tensor,
    dq: Tensor = None,
    gauss_num: int = 31,
    constant_dq: bool = False,
    abeles_func=None,
    **abeles_kwargs
):
    """Simulate reflectivity curves convolved with a Gaussian resolution kernel.

    The unsmeared curves are evaluated on a dense auxiliary q grid, convolved
    row by row with a Gaussian kernel built from ``dq``, and then linearly
    interpolated back onto the requested ``q`` points.

    Args:
        q: Momentum transfer values, one row per curve. A single row is
            repeated to match the batch size of ``thickness``.
        thickness: Layer thicknesses, forwarded to ``abeles_func``.
        roughness: Interface roughnesses, forwarded to ``abeles_func``.
        sld: Layer SLDs, forwarded to ``abeles_func``.
        dq: Resolution of each curve (presumably shaped ``(batch, 1)`` -- TODO
            confirm). Required despite the ``None`` default. With
            ``constant_dq=True`` it is used in absolute q units; otherwise it
            is treated as a relative resolution and the auxiliary grid is
            logarithmically spaced.
        gauss_num: Number of points in the Gaussian kernel.
        constant_dq: Choose the linear (absolute ``dq``) or the logarithmic
            (relative ``dq``) auxiliary grid.
        abeles_func: Reflectivity simulator; defaults to ``abeles``.
        **abeles_kwargs: Extra keyword arguments forwarded to ``abeles_func``.

    Returns:
        Tensor: Smeared reflectivity curves evaluated at ``q``.

    Raises:
        ValueError: If ``dq`` is not provided.
    """
    abeles_func = abeles_func or abeles

    if dq is None:
        raise ValueError("abeles_constant_smearing requires a resolution tensor `dq`.")

    # Match q and dq to the dtype AND device of the layer parameters.
    # ``Tensor.to(other)`` is a no-op when both already match, so the casts
    # are unconditional; each tensor is judged on its own dtype.
    q = q.to(thickness)
    dq = dq.to(thickness)

    if q.shape[0] == 1:
        q = q.repeat(thickness.shape[0], 1)

    q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
    kernels = _get_t_gauss_kernels(dq, gauss_num)

    curves = abeles_func(q_lin, thickness, roughness, sld, **abeles_kwargs)

    # Reflect-pad so the 'valid' convolution below keeps the grid length.
    padding = (kernels.shape[-1] - 1) // 2
    padded_curves = pad(curves, (padding, padding), 'reflect')

    # One group per curve: every row is convolved with its own kernel.
    smeared_curves = conv1d(
        padded_curves, kernels[:, None], groups=kernels.shape[0],
    )

    # If the simulator returned more rows than q has (presumably several
    # channels stacked along the batch axis -- TODO confirm), tile the grids.
    if q.shape[0] != smeared_curves.shape[0]:
        repeat_factor = smeared_curves.shape[0] // q.shape[0]
        q = q.repeat(repeat_factor, 1)
        q_lin = q_lin.repeat(repeat_factor, 1)

    smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)

    return smeared_curves
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
_FWHM = 2 * sqrt(2 * log(2.0))
|
|
56
|
-
_2PI_SQRT = 1. / sqrt(2 * pi)
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def _batch_linspace(start: Tensor, end: Tensor, num: int):
|
|
60
|
-
return torch.linspace(0, 1, int(num), device=end.device, dtype=end.dtype)[None] * (end - start) + start
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def _torch_gauss(x, s):
    """Evaluate the normalized zero-mean Gaussian density with std ``s`` at ``x``."""
    exponent = -0.5 * x ** 2 / s / s
    normalization = _2PI_SQRT / s
    return normalization * torch.exp(exponent)
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
    """Build one discrete Gaussian smearing kernel per resolution value.

    Each kernel samples ``gaussnum`` points on ``[-1.7 * res, 1.7 * res]``
    (about +/-4 sigma, with sigma = res / FWHM) and is multiplied by the grid
    step so that it approximates unit area.
    """
    half_width = 1.7 * resolutions
    sample_x = _batch_linspace(-half_width, half_width, gaussnum)
    step = (sample_x[:, 1] - sample_x[:, 0])[:, None]
    return _torch_gauss(sample_x, resolutions / _FWHM) * step
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = False):
    """Build the dense auxiliary q grid on which unsmeared curves are computed.

    Uses a linearly spaced grid when ``constant_dq`` is true and a
    logarithmically spaced grid otherwise.
    """
    build_axes = _get_q_axes_for_constant_dq if constant_dq else _get_q_axes_for_linear_dq
    return build_axes(q, resolutions, gaussnum)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51):
    """Logarithmically spaced auxiliary q grid for a relative resolution.

    The grid spans ``min(q)`` (clamped to at least 1e-6) to ``max(q)``, widened
    by ``6 * res / FWHM`` on both sides. Its log10 spacing is
    ``1.7 * res / FWHM / ((gaussnum - 1) / 2)``. Rows with fewer points are
    left-padded by ``_batch_linspace_with_padding``.
    """
    half_kernel = (gaussnum - 1) / 2

    q_low = torch.clamp_min_(q.min(1).values, 1e-6)
    q_high = q.max(1).values

    margin = 6 * resolutions / _FWHM
    log_start = torch.log10(q_low)[:, None] - margin
    log_end = torch.log10(q_high[:, None] * (1 + margin))

    log_step = 1.7 * resolutions / _FWHM / half_kernel
    n_points = torch.abs(torch.abs(log_end - log_start) / log_step).round().to(int)

    return 10 ** _batch_linspace_with_padding(log_start, log_end, n_points)
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
def _get_q_axes_for_constant_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51) -> Tensor:
    """Linearly spaced auxiliary q grid for an absolute (constant) resolution.

    The grid spans ``[min(q) - 1.7 * res, max(q) + 1.7 * res]`` with spacing
    ``1.7 * res / ((gaussnum - 1) / 2)``, matching the kernel sampling. Values
    are clamped to at least 1e-6.
    """
    half_kernel = (gaussnum - 1) / 2

    reach = resolutions * 1.7
    lower = q.min(1).values[:, None] - reach
    upper = q.max(1).values[:, None] + reach

    step = 1.7 * resolutions / half_kernel
    n_points = torch.abs(torch.abs(upper - lower) / step).round().to(int)

    grid = _batch_linspace_with_padding(lower, upper, n_points)
    return torch.clamp_min_(grid, 1e-6)
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
def _batch_linspace_with_padding(start: Tensor, end: Tensor, nums: Tensor) -> Tensor:
    """Pack per-row linear grids of different lengths into one tensor.

    Every row is widened to the common width ``max(nums)``. The last
    ``nums[i]`` entries of row ``i`` run linearly from ``start[i]`` to
    ``end[i]``; the leading padding entries repeat ``start[i]``.
    """
    width = nums.max().int().item()
    step = 1 / (nums - 1)

    # Shift each row's unit interval to the left so that only its final
    # nums[i] samples land inside [0, 1], with spacing 1 / (nums[i] - 1).
    lower = step * (nums - width)
    upper = torch.ones_like(step)
    unit_grid = torch.linspace(0, 1, int(width), device=upper.device, dtype=upper.dtype)
    fractions = (unit_grid[None] * (upper - lower) + lower).clamp_min_(0)

    return fractions * (end - start) + start
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
def _batch_linear_interp1d(x: Tensor, y: Tensor, x_new: Tensor) -> Tensor:
|
|
127
|
-
eps = torch.finfo(y.dtype).eps
|
|
128
|
-
|
|
129
|
-
ind = torch.searchsorted(x.contiguous(), x_new.contiguous())
|
|
130
|
-
|
|
131
|
-
ind = torch.clamp_(ind - 1, 0, x.shape[-1] - 2)
|
|
132
|
-
slopes = (y[..., 1:] - y[..., :-1]) / (eps + (x[..., 1:] - x[..., :-1]))
|
|
133
|
-
ind_y = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * y.shape[1]
|
|
134
|
-
ind_slopes = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * slopes.shape[1]
|
|
135
|
-
|
|
136
|
-
y_new = y.flatten()[ind_y] + slopes.flatten()[ind_slopes] * (x_new - x.flatten()[ind_y])
|
|
137
|
-
|
|
138
|
-
return y_new
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
from math import pi, sqrt, log
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from torch import Tensor
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.reflectivity.abeles import abeles
|
|
8
|
+
from torch.nn.functional import conv1d, pad
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def abeles_constant_smearing(
    q: Tensor,
    thickness: Tensor,
    roughness: Tensor,
    sld: Tensor,
    dq: Tensor = None,
    gauss_num: int = 31,
    constant_dq: bool = False,
    abeles_func=None,
    **abeles_kwargs
):
    """Simulate reflectivity curves convolved with a Gaussian resolution kernel.

    The unsmeared curves are evaluated on a dense auxiliary q grid, convolved
    row by row with a Gaussian kernel built from ``dq``, and then linearly
    interpolated back onto the requested ``q`` points.

    Args:
        q: Momentum transfer values, one row per curve. A single row is
            repeated to match the batch size of ``thickness``.
        thickness: Layer thicknesses, forwarded to ``abeles_func``.
        roughness: Interface roughnesses, forwarded to ``abeles_func``.
        sld: Layer SLDs, forwarded to ``abeles_func``.
        dq: Resolution of each curve (presumably shaped ``(batch, 1)`` -- TODO
            confirm). Required despite the ``None`` default. With
            ``constant_dq=True`` it is used in absolute q units; otherwise it
            is treated as a relative resolution and the auxiliary grid is
            logarithmically spaced.
        gauss_num: Number of points in the Gaussian kernel.
        constant_dq: Choose the linear (absolute ``dq``) or the logarithmic
            (relative ``dq``) auxiliary grid.
        abeles_func: Reflectivity simulator; defaults to ``abeles``.
        **abeles_kwargs: Extra keyword arguments forwarded to ``abeles_func``.

    Returns:
        Tensor: Smeared reflectivity curves evaluated at ``q``.

    Raises:
        ValueError: If ``dq`` is not provided.
    """
    abeles_func = abeles_func or abeles

    if dq is None:
        raise ValueError("abeles_constant_smearing requires a resolution tensor `dq`.")

    # Match q and dq to the dtype AND device of the layer parameters.
    # ``Tensor.to(other)`` is a no-op when both already match, so the casts
    # are unconditional; each tensor is judged on its own dtype.
    q = q.to(thickness)
    dq = dq.to(thickness)

    if q.shape[0] == 1:
        q = q.repeat(thickness.shape[0], 1)

    q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
    kernels = _get_t_gauss_kernels(dq, gauss_num)

    curves = abeles_func(q_lin, thickness, roughness, sld, **abeles_kwargs)

    # Reflect-pad so the 'valid' convolution below keeps the grid length.
    padding = (kernels.shape[-1] - 1) // 2
    padded_curves = pad(curves, (padding, padding), 'reflect')

    # One group per curve: every row is convolved with its own kernel.
    smeared_curves = conv1d(
        padded_curves, kernels[:, None], groups=kernels.shape[0],
    )

    # If the simulator returned more rows than q has (presumably several
    # channels stacked along the batch axis -- TODO confirm), tile the grids.
    if q.shape[0] != smeared_curves.shape[0]:
        repeat_factor = smeared_curves.shape[0] // q.shape[0]
        q = q.repeat(repeat_factor, 1)
        q_lin = q_lin.repeat(repeat_factor, 1)

    smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)

    return smeared_curves
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
_FWHM = 2 * sqrt(2 * log(2.0))
|
|
56
|
+
_2PI_SQRT = 1. / sqrt(2 * pi)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _batch_linspace(start: Tensor, end: Tensor, num: int):
|
|
60
|
+
return torch.linspace(0, 1, int(num), device=end.device, dtype=end.dtype)[None] * (end - start) + start
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _torch_gauss(x, s):
    """Evaluate the normalized zero-mean Gaussian density with std ``s`` at ``x``."""
    exponent = -0.5 * x ** 2 / s / s
    normalization = _2PI_SQRT / s
    return normalization * torch.exp(exponent)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
    """Build one discrete Gaussian smearing kernel per resolution value.

    Each kernel samples ``gaussnum`` points on ``[-1.7 * res, 1.7 * res]``
    (about +/-4 sigma, with sigma = res / FWHM) and is multiplied by the grid
    step so that it approximates unit area.
    """
    half_width = 1.7 * resolutions
    sample_x = _batch_linspace(-half_width, half_width, gaussnum)
    step = (sample_x[:, 1] - sample_x[:, 0])[:, None]
    return _torch_gauss(sample_x, resolutions / _FWHM) * step
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = False):
    """Build the dense auxiliary q grid on which unsmeared curves are computed.

    Uses a linearly spaced grid when ``constant_dq`` is true and a
    logarithmically spaced grid otherwise.
    """
    build_axes = _get_q_axes_for_constant_dq if constant_dq else _get_q_axes_for_linear_dq
    return build_axes(q, resolutions, gaussnum)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51):
    """Logarithmically spaced auxiliary q grid for a relative resolution.

    The grid spans ``min(q)`` (clamped to at least 1e-6) to ``max(q)``, widened
    by ``6 * res / FWHM`` on both sides. Its log10 spacing is
    ``1.7 * res / FWHM / ((gaussnum - 1) / 2)``. Rows with fewer points are
    left-padded by ``_batch_linspace_with_padding``.
    """
    half_kernel = (gaussnum - 1) / 2

    q_low = torch.clamp_min_(q.min(1).values, 1e-6)
    q_high = q.max(1).values

    margin = 6 * resolutions / _FWHM
    log_start = torch.log10(q_low)[:, None] - margin
    log_end = torch.log10(q_high[:, None] * (1 + margin))

    log_step = 1.7 * resolutions / _FWHM / half_kernel
    n_points = torch.abs(torch.abs(log_end - log_start) / log_step).round().to(int)

    return 10 ** _batch_linspace_with_padding(log_start, log_end, n_points)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _get_q_axes_for_constant_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51) -> Tensor:
    """Linearly spaced auxiliary q grid for an absolute (constant) resolution.

    The grid spans ``[min(q) - 1.7 * res, max(q) + 1.7 * res]`` with spacing
    ``1.7 * res / ((gaussnum - 1) / 2)``, matching the kernel sampling. Values
    are clamped to at least 1e-6.
    """
    half_kernel = (gaussnum - 1) / 2

    reach = resolutions * 1.7
    lower = q.min(1).values[:, None] - reach
    upper = q.max(1).values[:, None] + reach

    step = 1.7 * resolutions / half_kernel
    n_points = torch.abs(torch.abs(upper - lower) / step).round().to(int)

    grid = _batch_linspace_with_padding(lower, upper, n_points)
    return torch.clamp_min_(grid, 1e-6)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _batch_linspace_with_padding(start: Tensor, end: Tensor, nums: Tensor) -> Tensor:
    """Pack per-row linear grids of different lengths into one tensor.

    Every row is widened to the common width ``max(nums)``. The last
    ``nums[i]`` entries of row ``i`` run linearly from ``start[i]`` to
    ``end[i]``; the leading padding entries repeat ``start[i]``.
    """
    width = nums.max().int().item()
    step = 1 / (nums - 1)

    # Shift each row's unit interval to the left so that only its final
    # nums[i] samples land inside [0, 1], with spacing 1 / (nums[i] - 1).
    lower = step * (nums - width)
    upper = torch.ones_like(step)
    unit_grid = torch.linspace(0, 1, int(width), device=upper.device, dtype=upper.dtype)
    fractions = (unit_grid[None] * (upper - lower) + lower).clamp_min_(0)

    return fractions * (end - start) + start
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _batch_linear_interp1d(x: Tensor, y: Tensor, x_new: Tensor) -> Tensor:
|
|
127
|
+
eps = torch.finfo(y.dtype).eps
|
|
128
|
+
|
|
129
|
+
ind = torch.searchsorted(x.contiguous(), x_new.contiguous())
|
|
130
|
+
|
|
131
|
+
ind = torch.clamp_(ind - 1, 0, x.shape[-1] - 2)
|
|
132
|
+
slopes = (y[..., 1:] - y[..., :-1]) / (eps + (x[..., 1:] - x[..., :-1]))
|
|
133
|
+
ind_y = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * y.shape[1]
|
|
134
|
+
ind_slopes = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * slopes.shape[1]
|
|
135
|
+
|
|
136
|
+
y_new = y.flatten()[ind_y] + slopes.flatten()[ind_slopes] * (x_new - x.flatten()[ind_y])
|
|
137
|
+
|
|
138
|
+
return y_new
|
|
@@ -1,110 +1,110 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
import scipy
|
|
3
|
-
import numpy as np
|
|
4
|
-
from functools import lru_cache
|
|
5
|
-
from typing import Tuple
|
|
6
|
-
|
|
7
|
-
from reflectorch.data_generation.reflectivity.abeles import abeles
|
|
8
|
-
|
|
9
|
-
#Pytorch version based on the JAX implementation of pointwise smearing in the refnx package.
|
|
10
|
-
|
|
11
|
-
@lru_cache(maxsize=128)
def gauss_legendre(n: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculate Gauss-Legendre quadrature abscissae and weights on [-1, 1].

    Results are memoized with ``lru_cache``, so the returned arrays are shared
    between calls. Callers must not modify them in place; copy them first,
    e.g. with ``torch.tensor``.

    Args:
        n (int): Gaussian quadrature order (number of nodes).

    Returns:
        Tuple[np.ndarray, np.ndarray]: The abscissae and weights as NumPy arrays.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not make
    # ``scipy.special`` available on older SciPy releases. ``roots_legendre``
    # is the documented name of the routine ``p_roots`` aliases.
    from scipy.special import roots_legendre

    return roots_legendre(n)
|
|
23
|
-
|
|
24
|
-
def gauss(x: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the unnormalized standard Gaussian ``exp(-x**2 / 2)`` elementwise.

    The ``1 / sqrt(2 * pi)`` normalization is not applied here.

    Args:
        x (torch.Tensor): Points at which to evaluate the Gaussian.

    Returns:
        torch.Tensor: Tensor of the same shape as ``x``.
    """
    half_square = 0.5 * x * x
    return torch.exp(-half_square)
|
|
35
|
-
|
|
36
|
-
def abeles_pointwise_smearing(
    q: torch.Tensor,
    dq: torch.Tensor,
    thickness: torch.Tensor,
    roughness: torch.Tensor,
    sld: torch.Tensor,
    gauss_num: int = 17,
    abeles_func=None,
    **abeles_kwargs,
) -> torch.Tensor:
    """
    Compute reflectivity with pointwise (per-q) Gaussian resolution smearing.

    For every q point the reflectivity is integrated over
    ``q +/- 3.5 * dq / FWHM`` (i.e. +/-3.5 sigma, treating ``dq`` as a FWHM)
    with Gauss-Legendre quadrature of order ``gauss_num``.

    Args:
        q (torch.Tensor): The momentum transfer (q) values, one row per curve.
            A single row is repeated to the batch size of ``thickness``.
        dq (torch.Tensor): The resolution (FWHM) for each q point; indexed as
            ``dq[:, None, :]``, so it must be two-dimensional.
        thickness (torch.Tensor): The layer thicknesses.
        roughness (torch.Tensor): The interlayer roughnesses.
        sld (torch.Tensor): The SLDs of the layers.
        gauss_num (int, optional): Gaussian quadrature order. Defaults to 17.
        abeles_func (Callable, optional): A function implementing the
            simulation of the reflectivity curves. Defaults to ``abeles``.
        **abeles_kwargs: Forwarded unchanged to ``abeles_func``, e.g. magnetic
            parameters such as ``sld_magnetic``, ``magnetization_angle``,
            ``polarizer_eff`` and ``analyzer_eff`` when the chosen function
            supports them.

    Returns:
        torch.Tensor: The computed reflectivity curves, shaped ``(batch, nq)``,
        or ``(batch, n_channels, nq)`` when ``abeles_func`` returns channels.
    """
    abeles_func = abeles_func or abeles

    if q.shape[0] == 1:
        q = q.repeat(thickness.shape[0], 1)

    _FWHM = 2 * np.sqrt(2 * np.log(2.0))
    _INTLIMIT = 3.5

    bs = q.shape[0]
    nq = q.shape[-1]
    device = q.device

    quad_order = gauss_num
    abscissa, weights = gauss_legendre(quad_order)
    # Create the quadrature nodes in q's (floating) dtype. Otherwise the
    # float64 NumPy arrays would promote float32 inputs and outputs to
    # float64, and fail on backends without float64 support (e.g. MPS).
    # torch.tensor() copies, so the cached NumPy arrays are never shared.
    node_dtype = q.dtype if q.is_floating_point() else torch.get_default_dtype()
    abscissa = torch.tensor(abscissa, dtype=node_dtype, device=device)[None, :, None]
    weights = torch.tensor(weights, dtype=node_dtype, device=device)[None, :, None]
    prefactor = 1.0 / np.sqrt(2 * np.pi)

    # Normalized Gaussian evaluated at the nodes mapped to [-_INTLIMIT, _INTLIMIT] sigma.
    gaussvals = prefactor * gauss(abscissa * _INTLIMIT)

    va = q[:, None, :] - _INTLIMIT * dq[:, None, :] / _FWHM
    vb = q[:, None, :] + _INTLIMIT * dq[:, None, :] / _FWHM

    # Map nodes from [-1, 1] onto [va, vb]; shape (bs, quad_order, nq), then
    # flatten so abeles_func sees one long q row per curve.
    qvals_for_res_0 = (abscissa * (vb - va) + vb + va) / 2
    qvals_for_res = qvals_for_res_0.reshape(bs, -1)

    refl_curves = abeles_func(
        q=qvals_for_res,
        thickness=thickness,
        roughness=roughness,
        sld=sld,
        **abeles_kwargs
    )

    # Handle multiple channels. The factor _INTLIMIT is the Jacobian of the
    # change of variables from [-1, 1] to [-_INTLIMIT, _INTLIMIT].
    if refl_curves.dim() == 3:
        n_channels = refl_curves.shape[1]
        refl_curves = refl_curves.reshape(bs, n_channels, quad_order, nq)
        refl_curves = refl_curves * gaussvals.unsqueeze(1) * weights.unsqueeze(1)
        refl_curves = torch.sum(refl_curves, dim=2) * _INTLIMIT
    else:
        refl_curves = refl_curves.reshape(bs, quad_order, nq)
        refl_curves = refl_curves * gaussvals * weights
        refl_curves = torch.sum(refl_curves, dim=1) * _INTLIMIT

    return refl_curves
|
|
109
|
-
|
|
1
|
+
import torch
|
|
2
|
+
import scipy
|
|
3
|
+
import numpy as np
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from typing import Tuple
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.reflectivity.abeles import abeles
|
|
8
|
+
|
|
9
|
+
#Pytorch version based on the JAX implementation of pointwise smearing in the refnx package.
|
|
10
|
+
|
|
11
|
+
@lru_cache(maxsize=128)
def gauss_legendre(n: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculate Gauss-Legendre quadrature abscissae and weights on [-1, 1].

    Results are memoized with ``lru_cache``, so the returned arrays are shared
    between calls. Callers must not modify them in place; copy them first,
    e.g. with ``torch.tensor``.

    Args:
        n (int): Gaussian quadrature order (number of nodes).

    Returns:
        Tuple[np.ndarray, np.ndarray]: The abscissae and weights as NumPy arrays.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not make
    # ``scipy.special`` available on older SciPy releases. ``roots_legendre``
    # is the documented name of the routine ``p_roots`` aliases.
    from scipy.special import roots_legendre

    return roots_legendre(n)
|
|
23
|
+
|
|
24
|
+
def gauss(x: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the unnormalized standard Gaussian ``exp(-x**2 / 2)`` elementwise.

    The ``1 / sqrt(2 * pi)`` normalization is not applied here.

    Args:
        x (torch.Tensor): Points at which to evaluate the Gaussian.

    Returns:
        torch.Tensor: Tensor of the same shape as ``x``.
    """
    half_square = 0.5 * x * x
    return torch.exp(-half_square)
|
|
35
|
+
|
|
36
|
+
def abeles_pointwise_smearing(
    q: torch.Tensor,
    dq: torch.Tensor,
    thickness: torch.Tensor,
    roughness: torch.Tensor,
    sld: torch.Tensor,
    gauss_num: int = 17,
    abeles_func=None,
    **abeles_kwargs,
) -> torch.Tensor:
    """
    Compute reflectivity with pointwise (per-q) Gaussian resolution smearing.

    For every q point the reflectivity is integrated over
    ``q +/- 3.5 * dq / FWHM`` (i.e. +/-3.5 sigma, treating ``dq`` as a FWHM)
    with Gauss-Legendre quadrature of order ``gauss_num``.

    Args:
        q (torch.Tensor): The momentum transfer (q) values, one row per curve.
            A single row is repeated to the batch size of ``thickness``.
        dq (torch.Tensor): The resolution (FWHM) for each q point; indexed as
            ``dq[:, None, :]``, so it must be two-dimensional.
        thickness (torch.Tensor): The layer thicknesses.
        roughness (torch.Tensor): The interlayer roughnesses.
        sld (torch.Tensor): The SLDs of the layers.
        gauss_num (int, optional): Gaussian quadrature order. Defaults to 17.
        abeles_func (Callable, optional): A function implementing the
            simulation of the reflectivity curves. Defaults to ``abeles``.
        **abeles_kwargs: Forwarded unchanged to ``abeles_func``, e.g. magnetic
            parameters such as ``sld_magnetic``, ``magnetization_angle``,
            ``polarizer_eff`` and ``analyzer_eff`` when the chosen function
            supports them.

    Returns:
        torch.Tensor: The computed reflectivity curves, shaped ``(batch, nq)``,
        or ``(batch, n_channels, nq)`` when ``abeles_func`` returns channels.
    """
    abeles_func = abeles_func or abeles

    if q.shape[0] == 1:
        q = q.repeat(thickness.shape[0], 1)

    _FWHM = 2 * np.sqrt(2 * np.log(2.0))
    _INTLIMIT = 3.5

    bs = q.shape[0]
    nq = q.shape[-1]
    device = q.device

    quad_order = gauss_num
    abscissa, weights = gauss_legendre(quad_order)
    # Create the quadrature nodes in q's (floating) dtype. Otherwise the
    # float64 NumPy arrays would promote float32 inputs and outputs to
    # float64, and fail on backends without float64 support (e.g. MPS).
    # torch.tensor() copies, so the cached NumPy arrays are never shared.
    node_dtype = q.dtype if q.is_floating_point() else torch.get_default_dtype()
    abscissa = torch.tensor(abscissa, dtype=node_dtype, device=device)[None, :, None]
    weights = torch.tensor(weights, dtype=node_dtype, device=device)[None, :, None]
    prefactor = 1.0 / np.sqrt(2 * np.pi)

    # Normalized Gaussian evaluated at the nodes mapped to [-_INTLIMIT, _INTLIMIT] sigma.
    gaussvals = prefactor * gauss(abscissa * _INTLIMIT)

    va = q[:, None, :] - _INTLIMIT * dq[:, None, :] / _FWHM
    vb = q[:, None, :] + _INTLIMIT * dq[:, None, :] / _FWHM

    # Map nodes from [-1, 1] onto [va, vb]; shape (bs, quad_order, nq), then
    # flatten so abeles_func sees one long q row per curve.
    qvals_for_res_0 = (abscissa * (vb - va) + vb + va) / 2
    qvals_for_res = qvals_for_res_0.reshape(bs, -1)

    refl_curves = abeles_func(
        q=qvals_for_res,
        thickness=thickness,
        roughness=roughness,
        sld=sld,
        **abeles_kwargs
    )

    # Handle multiple channels. The factor _INTLIMIT is the Jacobian of the
    # change of variables from [-1, 1] to [-_INTLIMIT, _INTLIMIT].
    if refl_curves.dim() == 3:
        n_channels = refl_curves.shape[1]
        refl_curves = refl_curves.reshape(bs, n_channels, quad_order, nq)
        refl_curves = refl_curves * gaussvals.unsqueeze(1) * weights.unsqueeze(1)
        refl_curves = torch.sum(refl_curves, dim=2) * _INTLIMIT
    else:
        refl_curves = refl_curves.reshape(bs, quad_order, nq)
        refl_curves = refl_curves * gaussvals * weights
        refl_curves = torch.sum(refl_curves, dim=1) * _INTLIMIT

    return refl_curves
|