reflectorch 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (96)
  1. reflectorch/__init__.py +17 -17
  2. reflectorch/data_generation/__init__.py +128 -126
  3. reflectorch/data_generation/dataset.py +210 -210
  4. reflectorch/data_generation/likelihoods.py +80 -80
  5. reflectorch/data_generation/noise.py +470 -470
  6. reflectorch/data_generation/priors/__init__.py +60 -60
  7. reflectorch/data_generation/priors/base.py +55 -55
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -298
  9. reflectorch/data_generation/priors/independent_priors.py +195 -195
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -311
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -104
  12. reflectorch/data_generation/priors/no_constraints.py +206 -206
  13. reflectorch/data_generation/priors/parametric_models.py +841 -841
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -369
  15. reflectorch/data_generation/priors/params.py +252 -252
  16. reflectorch/data_generation/priors/sampler_strategies.py +369 -369
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -65
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -371
  19. reflectorch/data_generation/priors/utils.py +118 -118
  20. reflectorch/data_generation/process_data.py +41 -41
  21. reflectorch/data_generation/q_generator.py +280 -246
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -102
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -97
  24. reflectorch/data_generation/reflectivity/kinematical.py +70 -70
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -105
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -120
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -138
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +109 -109
  29. reflectorch/data_generation/scale_curves.py +112 -112
  30. reflectorch/data_generation/smearing.py +98 -98
  31. reflectorch/data_generation/utils.py +223 -222
  32. reflectorch/extensions/jupyter/__init__.py +11 -6
  33. reflectorch/extensions/jupyter/api.py +85 -0
  34. reflectorch/extensions/jupyter/callbacks.py +34 -34
  35. reflectorch/extensions/jupyter/components.py +758 -0
  36. reflectorch/extensions/jupyter/custom_select.py +268 -0
  37. reflectorch/extensions/jupyter/log_widget.py +241 -0
  38. reflectorch/extensions/jupyter/model_selection.py +495 -0
  39. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  40. reflectorch/extensions/jupyter/widget.py +625 -0
  41. reflectorch/extensions/matplotlib/__init__.py +5 -5
  42. reflectorch/extensions/matplotlib/losses.py +32 -32
  43. reflectorch/extensions/refnx/refnx_conversion.py +76 -76
  44. reflectorch/inference/__init__.py +28 -24
  45. reflectorch/inference/inference_model.py +847 -851
  46. reflectorch/inference/input_interface.py +239 -0
  47. reflectorch/inference/loading_data.py +37 -0
  48. reflectorch/inference/multilayer_fitter.py +171 -171
  49. reflectorch/inference/multilayer_inference_model.py +193 -193
  50. reflectorch/inference/plotting.py +524 -98
  51. reflectorch/inference/preprocess_exp/__init__.py +6 -6
  52. reflectorch/inference/preprocess_exp/attenuation.py +36 -36
  53. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -31
  54. reflectorch/inference/preprocess_exp/footprint.py +81 -81
  55. reflectorch/inference/preprocess_exp/interpolation.py +19 -16
  56. reflectorch/inference/preprocess_exp/normalize.py +21 -21
  57. reflectorch/inference/preprocess_exp/preprocess.py +121 -121
  58. reflectorch/inference/query_matcher.py +81 -81
  59. reflectorch/inference/record_time.py +43 -43
  60. reflectorch/inference/sampler_solution.py +56 -56
  61. reflectorch/inference/scipy_fitter.py +272 -248
  62. reflectorch/inference/torch_fitter.py +87 -87
  63. reflectorch/ml/__init__.py +32 -32
  64. reflectorch/ml/basic_trainer.py +292 -292
  65. reflectorch/ml/callbacks.py +80 -80
  66. reflectorch/ml/dataloaders.py +26 -26
  67. reflectorch/ml/loggers.py +55 -55
  68. reflectorch/ml/schedulers.py +355 -355
  69. reflectorch/ml/trainers.py +200 -191
  70. reflectorch/ml/utils.py +2 -2
  71. reflectorch/models/__init__.py +15 -14
  72. reflectorch/models/activations.py +50 -50
  73. reflectorch/models/encoders/__init__.py +19 -17
  74. reflectorch/models/encoders/conv_encoder.py +218 -218
  75. reflectorch/models/encoders/conv_res_net.py +115 -115
  76. reflectorch/models/encoders/fno.py +133 -133
  77. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  78. reflectorch/models/networks/__init__.py +14 -14
  79. reflectorch/models/networks/mlp_networks.py +434 -428
  80. reflectorch/models/networks/residual_net.py +156 -156
  81. reflectorch/paths.py +29 -27
  82. reflectorch/runs/__init__.py +31 -31
  83. reflectorch/runs/config.py +25 -25
  84. reflectorch/runs/slurm_utils.py +93 -93
  85. reflectorch/runs/train.py +78 -78
  86. reflectorch/runs/utils.py +404 -401
  87. reflectorch/test_config.py +4 -4
  88. reflectorch/train.py +4 -4
  89. reflectorch/train_on_cluster.py +4 -4
  90. reflectorch/utils.py +98 -68
  91. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/METADATA +129 -125
  92. reflectorch-1.5.0.dist-info/RECORD +96 -0
  93. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/WHEEL +1 -1
  94. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/licenses/LICENSE.txt +20 -20
  95. reflectorch-1.3.0.dist-info/RECORD +0 -86
  96. {reflectorch-1.3.0.dist-info → reflectorch-1.5.0.dist-info}/top_level.txt +0 -0
reflectorch/data_generation/reflectivity/smearing.py
@@ -1,138 +1,138 @@
-# -*- coding: utf-8 -*-
-from math import pi, sqrt, log
-
-import torch
-from torch import Tensor
-
-from reflectorch.data_generation.reflectivity.abeles import abeles
-from torch.nn.functional import conv1d, pad
-
-
-def abeles_constant_smearing(
-        q: Tensor,
-        thickness: Tensor,
-        roughness: Tensor,
-        sld: Tensor,
-        dq: Tensor = None,
-        gauss_num: int = 31,
-        constant_dq: bool = False,
-        abeles_func=None,
-        **abeles_kwargs
-):
-    abeles_func = abeles_func or abeles
-
-    if dq.dtype != thickness.dtype:
-        q = q.to(thickness)
-
-    if dq.dtype != thickness.dtype:
-        dq = dq.to(thickness)
-
-    if q.shape[0] == 1:
-        q = q.repeat(thickness.shape[0], 1)
-
-    q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
-    kernels = _get_t_gauss_kernels(dq, gauss_num)
-
-    curves = abeles_func(q_lin, thickness, roughness, sld, **abeles_kwargs)
-
-    padding = (kernels.shape[-1] - 1) // 2
-    padded_curves = pad(curves, (padding, padding), 'reflect')
-
-    smeared_curves = conv1d(
-        padded_curves, kernels[:, None], groups=kernels.shape[0],
-    )
-
-    if q.shape[0] != smeared_curves.shape[0]:
-        repeat_factor = smeared_curves.shape[0] // q.shape[0]
-        q = q.repeat(repeat_factor, 1)
-        q_lin = q_lin.repeat(repeat_factor, 1)
-
-    smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)
-
-    return smeared_curves
-
-
-_FWHM = 2 * sqrt(2 * log(2.0))
-_2PI_SQRT = 1. / sqrt(2 * pi)
-
-
-def _batch_linspace(start: Tensor, end: Tensor, num: int):
-    return torch.linspace(0, 1, int(num), device=end.device, dtype=end.dtype)[None] * (end - start) + start
-
-
-def _torch_gauss(x, s):
-    return _2PI_SQRT / s * torch.exp(-0.5 * x ** 2 / s / s)
-
-
-def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
-    gauss_x = _batch_linspace(-1.7 * resolutions, 1.7 * resolutions, gaussnum)
-    gauss_y = _torch_gauss(gauss_x, resolutions / _FWHM) * (gauss_x[:, 1] - gauss_x[:, 0])[:, None]
-    return gauss_y
-
-
-def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = False):
-    if constant_dq:
-        return _get_q_axes_for_constant_dq(q, resolutions, gaussnum)
-    else:
-        return _get_q_axes_for_linear_dq(q, resolutions, gaussnum)
-
-
-def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51):
-    gaussgpoint = (gaussnum - 1) / 2
-
-    lowq = torch.clamp_min_(q.min(1).values, 1e-6)
-    highq = q.max(1).values
-
-    start = torch.log10(lowq)[:, None] - 6 * resolutions / _FWHM
-    end = torch.log10(highq[:, None] * (1 + 6 * resolutions / _FWHM))
-
-    interpnums = torch.abs(
-        (torch.abs(end - start)) / (1.7 * resolutions / _FWHM / gaussgpoint)
-    ).round().to(int)
-
-    q_lin = 10 ** _batch_linspace_with_padding(start, end, interpnums)
-
-    return q_lin
-
-
-def _get_q_axes_for_constant_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51) -> Tensor:
-    gaussgpoint = (gaussnum - 1) / 2
-
-    start = q.min(1).values[:, None] - resolutions * 1.7
-    end = q.max(1).values[:, None] + resolutions * 1.7
-
-    interpnums = torch.abs(
-        (torch.abs(end - start)) / (1.7 * resolutions / gaussgpoint)
-    ).round().to(int)
-
-    q_lin = _batch_linspace_with_padding(start, end, interpnums)
-    q_lin = torch.clamp_min_(q_lin, 1e-6)
-
-    return q_lin
-
-
-def _batch_linspace_with_padding(start: Tensor, end: Tensor, nums: Tensor) -> Tensor:
-    max_num = nums.max().int().item()
-
-    deltas = 1 / (nums - 1)
-
-    x = torch.clamp_min_(_batch_linspace(deltas * (nums - max_num), torch.ones_like(deltas), max_num), 0)
-
-    x = x * (end - start) + start
-
-    return x
-
-
-def _batch_linear_interp1d(x: Tensor, y: Tensor, x_new: Tensor) -> Tensor:
-    eps = torch.finfo(y.dtype).eps
-
-    ind = torch.searchsorted(x.contiguous(), x_new.contiguous())
-
-    ind = torch.clamp_(ind - 1, 0, x.shape[-1] - 2)
-    slopes = (y[..., 1:] - y[..., :-1]) / (eps + (x[..., 1:] - x[..., :-1]))
-    ind_y = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * y.shape[1]
-    ind_slopes = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * slopes.shape[1]
-
-    y_new = y.flatten()[ind_y] + slopes.flatten()[ind_slopes] * (x_new - x.flatten()[ind_y])
-
-    return y_new
+# -*- coding: utf-8 -*-
+from math import pi, sqrt, log
+
+import torch
+from torch import Tensor
+
+from reflectorch.data_generation.reflectivity.abeles import abeles
+from torch.nn.functional import conv1d, pad
+
+
+def abeles_constant_smearing(
+        q: Tensor,
+        thickness: Tensor,
+        roughness: Tensor,
+        sld: Tensor,
+        dq: Tensor = None,
+        gauss_num: int = 31,
+        constant_dq: bool = False,
+        abeles_func=None,
+        **abeles_kwargs
+):
+    abeles_func = abeles_func or abeles
+
+    if dq.dtype != thickness.dtype:
+        q = q.to(thickness)
+
+    if dq.dtype != thickness.dtype:
+        dq = dq.to(thickness)
+
+    if q.shape[0] == 1:
+        q = q.repeat(thickness.shape[0], 1)
+
+    q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
+    kernels = _get_t_gauss_kernels(dq, gauss_num)
+
+    curves = abeles_func(q_lin, thickness, roughness, sld, **abeles_kwargs)
+
+    padding = (kernels.shape[-1] - 1) // 2
+    padded_curves = pad(curves, (padding, padding), 'reflect')
+
+    smeared_curves = conv1d(
+        padded_curves, kernels[:, None], groups=kernels.shape[0],
+    )
+
+    if q.shape[0] != smeared_curves.shape[0]:
+        repeat_factor = smeared_curves.shape[0] // q.shape[0]
+        q = q.repeat(repeat_factor, 1)
+        q_lin = q_lin.repeat(repeat_factor, 1)
+
+    smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)
+
+    return smeared_curves
+
+
+_FWHM = 2 * sqrt(2 * log(2.0))
+_2PI_SQRT = 1. / sqrt(2 * pi)
+
+
+def _batch_linspace(start: Tensor, end: Tensor, num: int):
+    return torch.linspace(0, 1, int(num), device=end.device, dtype=end.dtype)[None] * (end - start) + start
+
+
+def _torch_gauss(x, s):
+    return _2PI_SQRT / s * torch.exp(-0.5 * x ** 2 / s / s)
+
+
+def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
+    gauss_x = _batch_linspace(-1.7 * resolutions, 1.7 * resolutions, gaussnum)
+    gauss_y = _torch_gauss(gauss_x, resolutions / _FWHM) * (gauss_x[:, 1] - gauss_x[:, 0])[:, None]
+    return gauss_y
+
+
+def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = False):
+    if constant_dq:
+        return _get_q_axes_for_constant_dq(q, resolutions, gaussnum)
+    else:
+        return _get_q_axes_for_linear_dq(q, resolutions, gaussnum)
+
+
+def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51):
+    gaussgpoint = (gaussnum - 1) / 2
+
+    lowq = torch.clamp_min_(q.min(1).values, 1e-6)
+    highq = q.max(1).values
+
+    start = torch.log10(lowq)[:, None] - 6 * resolutions / _FWHM
+    end = torch.log10(highq[:, None] * (1 + 6 * resolutions / _FWHM))
+
+    interpnums = torch.abs(
+        (torch.abs(end - start)) / (1.7 * resolutions / _FWHM / gaussgpoint)
+    ).round().to(int)
+
+    q_lin = 10 ** _batch_linspace_with_padding(start, end, interpnums)
+
+    return q_lin
+
+
+def _get_q_axes_for_constant_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51) -> Tensor:
+    gaussgpoint = (gaussnum - 1) / 2
+
+    start = q.min(1).values[:, None] - resolutions * 1.7
+    end = q.max(1).values[:, None] + resolutions * 1.7
+
+    interpnums = torch.abs(
+        (torch.abs(end - start)) / (1.7 * resolutions / gaussgpoint)
+    ).round().to(int)
+
+    q_lin = _batch_linspace_with_padding(start, end, interpnums)
+    q_lin = torch.clamp_min_(q_lin, 1e-6)
+
+    return q_lin
+
+
+def _batch_linspace_with_padding(start: Tensor, end: Tensor, nums: Tensor) -> Tensor:
+    max_num = nums.max().int().item()
+
+    deltas = 1 / (nums - 1)
+
+    x = torch.clamp_min_(_batch_linspace(deltas * (nums - max_num), torch.ones_like(deltas), max_num), 0)
+
+    x = x * (end - start) + start
+
+    return x
+
+
+def _batch_linear_interp1d(x: Tensor, y: Tensor, x_new: Tensor) -> Tensor:
+    eps = torch.finfo(y.dtype).eps
+
+    ind = torch.searchsorted(x.contiguous(), x_new.contiguous())
+
+    ind = torch.clamp_(ind - 1, 0, x.shape[-1] - 2)
+    slopes = (y[..., 1:] - y[..., :-1]) / (eps + (x[..., 1:] - x[..., :-1]))
+    ind_y = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * y.shape[1]
+    ind_slopes = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * slopes.shape[1]
+
+    y_new = y.flatten()[ind_y] + slopes.flatten()[ind_slopes] * (x_new - x.flatten()[ind_y])
+
+    return y_new
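The smearing.py hunk above defines `abeles_constant_smearing`, which simulates the reflectivity on an oversampled q grid, convolves it with a Gaussian kernel, and interpolates back to the requested q values. As a quick orientation, here is a minimal usage sketch: the import path follows the file location in the table above, while the tensor shapes (thickness `(batch, N)`, roughness and sld `(batch, N + 1)`, dq `(batch, 1)`) and all numeric values are illustrative assumptions rather than documented reflectorch defaults.

```python
import torch
from reflectorch.data_generation.reflectivity.smearing import abeles_constant_smearing

# Hypothetical batch of identical two-layer films on a substrate.
batch_size, n_layers = 4, 2
q = torch.linspace(0.02, 0.15, 128)[None, :]                     # (1, n_q); broadcast over the batch inside the function
thickness = torch.tensor([[80.0, 120.0]]).repeat(batch_size, 1)  # (batch, N); angstroms assumed
roughness = torch.full((batch_size, n_layers + 1), 3.0)          # (batch, N + 1)
sld = torch.tensor([[2.0, 4.5, 2.07]]).repeat(batch_size, 1)     # (batch, N + 1); layer SLDs + substrate (assumed convention)
dq = torch.full((batch_size, 1), 0.05)                           # one smearing width per curve; 0.05 is an arbitrary value

smeared = abeles_constant_smearing(
    q, thickness, roughness, sld,
    dq=dq, gauss_num=31, constant_dq=False,
)
print(smeared.shape)  # expected: (batch_size, n_q)
```

Passing q with a leading dimension of 1 relies on the `q.shape[0] == 1` branch at the top of the function, which repeats q across the batch before the oversampled grid is built.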
reflectorch/data_generation/reflectivity/smearing_pointwise.py
@@ -1,110 +1,110 @@
-import torch
-import scipy
-import numpy as np
-from functools import lru_cache
-from typing import Tuple
-
-from reflectorch.data_generation.reflectivity.abeles import abeles
-
-#Pytorch version based on the JAX implementation of pointwise smearing in the refnx package.
-
-@lru_cache(maxsize=128)
-def gauss_legendre(n: int) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Calculate Gaussian quadrature abscissae and weights.
-
-    Args:
-        n (int): Gaussian quadrature order.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: The abscissae and weights for Gauss-Legendre integration.
-    """
-    return scipy.special.p_roots(n)
-
-def gauss(x: torch.Tensor) -> torch.Tensor:
-    """
-    Calculate the Gaussian function.
-
-    Args:
-        x (torch.Tensor): Input tensor.
-
-    Returns:
-        torch.Tensor: Output tensor after applying the Gaussian function.
-    """
-    return torch.exp(-0.5 * x * x)
-
-def abeles_pointwise_smearing(
-    q: torch.Tensor,
-    dq: torch.Tensor,
-    thickness: torch.Tensor,
-    roughness: torch.Tensor,
-    sld: torch.Tensor,
-    gauss_num: int = 17,
-    abeles_func=None,
-    **abeles_kwargs,
-) -> torch.Tensor:
-    """
-    Compute reflectivity with variable smearing using Gaussian quadrature.
-
-    Args:
-        q (torch.Tensor): The momentum transfer (q) values.
-        dq (torch.Tensor): The resolution for curve smearing.
-        thickness (torch.Tensor): The layer thicknesses.
-        roughness (torch.Tensor): The interlayer roughnesses.
-        sld (torch.Tensor): The SLDs of the layers.
-        sld_magnetic (torch.Tensor, optional): The magnetic SLDs of the layers.
-        magnetization_angle (torch.Tensor, optional): The magnetization angles.
-        polarizer_eff (torch.Tensor, optional): The polarizer efficiency.
-        analyzer_eff (torch.Tensor, optional): The analyzer efficiency.
-        abeles_func (Callable, optional): A function implementing the simulation of the reflectivity curves.
-        gauss_num (int, optional): Gaussian quadrature order. Defaults to 17.
-
-    Returns:
-        torch.Tensor: The computed reflectivity curves.
-    """
-    abeles_func = abeles_func or abeles
-
-    if q.shape[0] == 1:
-        q = q.repeat(thickness.shape[0], 1)
-
-    _FWHM = 2 * np.sqrt(2 * np.log(2.0))
-    _INTLIMIT = 3.5
-
-    bs = q.shape[0]
-    nq = q.shape[-1]
-    device = q.device
-
-    quad_order = gauss_num
-    abscissa, weights = gauss_legendre(quad_order)
-    abscissa = torch.tensor(abscissa)[None, :, None].to(device)
-    weights = torch.tensor(weights)[None, :, None].to(device)
-    prefactor = 1.0 / np.sqrt(2 * np.pi)
-
-    gaussvals = prefactor * gauss(abscissa * _INTLIMIT)
-
-    va = q[:, None, :] - _INTLIMIT * dq[:, None, :] / _FWHM
-    vb = q[:, None, :] + _INTLIMIT * dq[:, None, :] / _FWHM
-
-    qvals_for_res_0 = (abscissa * (vb - va) + vb + va) / 2
-    qvals_for_res = qvals_for_res_0.reshape(bs, -1)
-
-    refl_curves = abeles_func(
-        q=qvals_for_res,
-        thickness=thickness,
-        roughness=roughness,
-        sld=sld,
-        **abeles_kwargs
-    )
-
-    # Handle multiple channels
-    if refl_curves.dim() == 3:
-        n_channels = refl_curves.shape[1]
-        refl_curves = refl_curves.reshape(bs, n_channels, quad_order, nq)
-        refl_curves = refl_curves * gaussvals.unsqueeze(1) * weights.unsqueeze(1)
-        refl_curves = torch.sum(refl_curves, dim=2) * _INTLIMIT
-    else:
-        refl_curves = refl_curves.reshape(bs, quad_order, nq)
-        refl_curves = refl_curves * gaussvals * weights
-        refl_curves = torch.sum(refl_curves, dim=1) * _INTLIMIT
-
+import torch
+import scipy
+import numpy as np
+from functools import lru_cache
+from typing import Tuple
+
+from reflectorch.data_generation.reflectivity.abeles import abeles
+
+#Pytorch version based on the JAX implementation of pointwise smearing in the refnx package.
+
+@lru_cache(maxsize=128)
+def gauss_legendre(n: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Calculate Gaussian quadrature abscissae and weights.
+
+    Args:
+        n (int): Gaussian quadrature order.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The abscissae and weights for Gauss-Legendre integration.
+    """
+    return scipy.special.p_roots(n)
+
+def gauss(x: torch.Tensor) -> torch.Tensor:
+    """
+    Calculate the Gaussian function.
+
+    Args:
+        x (torch.Tensor): Input tensor.
+
+    Returns:
+        torch.Tensor: Output tensor after applying the Gaussian function.
+    """
+    return torch.exp(-0.5 * x * x)
+
+def abeles_pointwise_smearing(
+    q: torch.Tensor,
+    dq: torch.Tensor,
+    thickness: torch.Tensor,
+    roughness: torch.Tensor,
+    sld: torch.Tensor,
+    gauss_num: int = 17,
+    abeles_func=None,
+    **abeles_kwargs,
+) -> torch.Tensor:
+    """
+    Compute reflectivity with variable smearing using Gaussian quadrature.
+
+    Args:
+        q (torch.Tensor): The momentum transfer (q) values.
+        dq (torch.Tensor): The resolution for curve smearing.
+        thickness (torch.Tensor): The layer thicknesses.
+        roughness (torch.Tensor): The interlayer roughnesses.
+        sld (torch.Tensor): The SLDs of the layers.
+        sld_magnetic (torch.Tensor, optional): The magnetic SLDs of the layers.
+        magnetization_angle (torch.Tensor, optional): The magnetization angles.
+        polarizer_eff (torch.Tensor, optional): The polarizer efficiency.
+        analyzer_eff (torch.Tensor, optional): The analyzer efficiency.
+        abeles_func (Callable, optional): A function implementing the simulation of the reflectivity curves.
+        gauss_num (int, optional): Gaussian quadrature order. Defaults to 17.
+
+    Returns:
+        torch.Tensor: The computed reflectivity curves.
+    """
+    abeles_func = abeles_func or abeles
+
+    if q.shape[0] == 1:
+        q = q.repeat(thickness.shape[0], 1)
+
+    _FWHM = 2 * np.sqrt(2 * np.log(2.0))
+    _INTLIMIT = 3.5
+
+    bs = q.shape[0]
+    nq = q.shape[-1]
+    device = q.device
+
+    quad_order = gauss_num
+    abscissa, weights = gauss_legendre(quad_order)
+    abscissa = torch.tensor(abscissa)[None, :, None].to(device)
+    weights = torch.tensor(weights)[None, :, None].to(device)
+    prefactor = 1.0 / np.sqrt(2 * np.pi)
+
+    gaussvals = prefactor * gauss(abscissa * _INTLIMIT)
+
+    va = q[:, None, :] - _INTLIMIT * dq[:, None, :] / _FWHM
+    vb = q[:, None, :] + _INTLIMIT * dq[:, None, :] / _FWHM
+
+    qvals_for_res_0 = (abscissa * (vb - va) + vb + va) / 2
+    qvals_for_res = qvals_for_res_0.reshape(bs, -1)
+
+    refl_curves = abeles_func(
+        q=qvals_for_res,
+        thickness=thickness,
+        roughness=roughness,
+        sld=sld,
+        **abeles_kwargs
+    )
+
+    # Handle multiple channels
+    if refl_curves.dim() == 3:
+        n_channels = refl_curves.shape[1]
+        refl_curves = refl_curves.reshape(bs, n_channels, quad_order, nq)
+        refl_curves = refl_curves * gaussvals.unsqueeze(1) * weights.unsqueeze(1)
+        refl_curves = torch.sum(refl_curves, dim=2) * _INTLIMIT
+    else:
+        refl_curves = refl_curves.reshape(bs, quad_order, nq)
+        refl_curves = refl_curves * gaussvals * weights
+        refl_curves = torch.sum(refl_curves, dim=1) * _INTLIMIT
+
     return refl_curves
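The smearing_pointwise.py hunk defines `abeles_pointwise_smearing`, which integrates the reflectivity over a Gaussian resolution function with a per-point width dq using Gauss-Legendre quadrature. The sketch below shows one way it might be called: the import path mirrors the file location in the table above, float64 tensors are used to match the float64 quadrature nodes returned by `scipy.special.p_roots`, and the tensor shapes, the reading of dq as a FWHM-like width, and all numbers are assumptions for illustration only.

```python
import torch
from reflectorch.data_generation.reflectivity.smearing_pointwise import abeles_pointwise_smearing

# Hypothetical single-layer film with a point-by-point resolution array.
q = torch.linspace(0.01, 0.2, 200, dtype=torch.float64)[None, :]  # (1, n_q)
dq = 0.02 * q                                                     # per-point width, same shape as q (FWHM-like; assumed)
thickness = torch.tensor([[150.0]], dtype=torch.float64)          # (1, N)
roughness = torch.tensor([[2.0, 4.0]], dtype=torch.float64)       # (1, N + 1)
sld = torch.tensor([[3.5, 2.07]], dtype=torch.float64)            # (1, N + 1)

smeared = abeles_pointwise_smearing(
    q, dq, thickness, roughness, sld, gauss_num=17,
)
print(smeared.shape)  # expected: (1, n_q)
```

Because `gauss_legendre` is wrapped in `lru_cache`, repeated calls with the same `gauss_num` reuse the quadrature nodes instead of recomputing them.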