reflectorch 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (83)
  1. reflectorch/__init__.py +23 -0
  2. reflectorch/data_generation/__init__.py +130 -0
  3. reflectorch/data_generation/dataset.py +196 -0
  4. reflectorch/data_generation/likelihoods.py +86 -0
  5. reflectorch/data_generation/noise.py +371 -0
  6. reflectorch/data_generation/priors/__init__.py +66 -0
  7. reflectorch/data_generation/priors/base.py +61 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +304 -0
  9. reflectorch/data_generation/priors/independent_priors.py +201 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +110 -0
  12. reflectorch/data_generation/priors/no_constraints.py +212 -0
  13. reflectorch/data_generation/priors/parametric_models.py +767 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +354 -0
  15. reflectorch/data_generation/priors/params.py +258 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +306 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +377 -0
  19. reflectorch/data_generation/priors/utils.py +124 -0
  20. reflectorch/data_generation/process_data.py +47 -0
  21. reflectorch/data_generation/q_generator.py +232 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +56 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +81 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +58 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +92 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +123 -0
  28. reflectorch/data_generation/scale_curves.py +118 -0
  29. reflectorch/data_generation/smearing.py +67 -0
  30. reflectorch/data_generation/utils.py +154 -0
  31. reflectorch/extensions/__init__.py +6 -0
  32. reflectorch/extensions/jupyter/__init__.py +12 -0
  33. reflectorch/extensions/jupyter/callbacks.py +40 -0
  34. reflectorch/extensions/matplotlib/__init__.py +11 -0
  35. reflectorch/extensions/matplotlib/losses.py +38 -0
  36. reflectorch/inference/__init__.py +22 -0
  37. reflectorch/inference/inference_model.py +734 -0
  38. reflectorch/inference/multilayer_fitter.py +171 -0
  39. reflectorch/inference/multilayer_inference_model.py +193 -0
  40. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  41. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  42. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  43. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  44. reflectorch/inference/preprocess_exp/interpolation.py +16 -0
  45. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  46. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  47. reflectorch/inference/record_time.py +43 -0
  48. reflectorch/inference/sampler_solution.py +56 -0
  49. reflectorch/inference/scipy_fitter.py +171 -0
  50. reflectorch/inference/torch_fitter.py +87 -0
  51. reflectorch/ml/__init__.py +37 -0
  52. reflectorch/ml/basic_trainer.py +286 -0
  53. reflectorch/ml/callbacks.py +86 -0
  54. reflectorch/ml/dataloaders.py +27 -0
  55. reflectorch/ml/loggers.py +38 -0
  56. reflectorch/ml/schedulers.py +246 -0
  57. reflectorch/ml/trainers.py +126 -0
  58. reflectorch/ml/utils.py +9 -0
  59. reflectorch/models/__init__.py +22 -0
  60. reflectorch/models/activations.py +50 -0
  61. reflectorch/models/encoders/__init__.py +27 -0
  62. reflectorch/models/encoders/conv_encoder.py +211 -0
  63. reflectorch/models/encoders/conv_res_net.py +119 -0
  64. reflectorch/models/encoders/fno.py +127 -0
  65. reflectorch/models/encoders/transformers.py +56 -0
  66. reflectorch/models/networks/__init__.py +18 -0
  67. reflectorch/models/networks/mlp_networks.py +256 -0
  68. reflectorch/models/networks/residual_net.py +131 -0
  69. reflectorch/paths.py +33 -0
  70. reflectorch/runs/__init__.py +35 -0
  71. reflectorch/runs/config.py +31 -0
  72. reflectorch/runs/slurm_utils.py +99 -0
  73. reflectorch/runs/train.py +85 -0
  74. reflectorch/runs/utils.py +300 -0
  75. reflectorch/test_config.py +4 -0
  76. reflectorch/train.py +4 -0
  77. reflectorch/train_on_cluster.py +4 -0
  78. reflectorch/utils.py +74 -0
  79. reflectorch-1.0.0.dist-info/LICENSE.txt +621 -0
  80. reflectorch-1.0.0.dist-info/METADATA +115 -0
  81. reflectorch-1.0.0.dist-info/RECORD +83 -0
  82. reflectorch-1.0.0.dist-info/WHEEL +5 -0
  83. reflectorch-1.0.0.dist-info/top_level.txt +1 -0
reflectorch/data_generation/q_generator.py
@@ -0,0 +1,232 @@
+ # -*- coding: utf-8 -*-
+ #
+ #
+ # This source code is licensed under the GPL license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from typing import Tuple, Union
+
+ import numpy as np
+
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.utils import uniform_sampler
+ from reflectorch.data_generation.priors import BasicParams
+ from reflectorch.utils import angle_to_q
+ from reflectorch.data_generation.priors.no_constraints import DEFAULT_DEVICE, DEFAULT_DTYPE
+
+ __all__ = [
+     "QGenerator",
+     "ConstantQ",
+     "VariableQ",
+     "EquidistantQ",
+     "ConstantAngle",
+ ]
+
+
+ class QGenerator(object):
+     """Base class for momentum transfer (q) generators"""
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         raise NotImplementedError
+
+
+ class ConstantQ(QGenerator):
+     """Q generator for reflectivity curves with fixed discretization
+
+     Args:
+         q (Union[Tensor, Tuple[float, float, int]], optional): tuple (q_min, q_max, num_q) defining the minimum
+             q value, the maximum q value and the number of q points. Defaults to (0., 0.2, 128).
+         device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
+         dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
+         remove_zero (bool, optional): if True, one endpoint of the q grid is dropped: the last point (q_max)
+             by default, or the first point (q_min) when fixed_zero is True. Defaults to False.
+         fixed_zero (bool, optional): selects which endpoint remove_zero drops; True drops the first point.
+             Defaults to False.
+     """
+
+     def __init__(self,
+                  q: Union[Tensor, Tuple[float, float, int]] = (0., 0.2, 128),
+                  device=DEFAULT_DEVICE,
+                  dtype=DEFAULT_DTYPE,
+                  remove_zero: bool = False,
+                  fixed_zero: bool = False,
+                  ):
+         if isinstance(q, (tuple, list)):
+             q = torch.linspace(*q, device=device, dtype=dtype)
+         if remove_zero:
+             if fixed_zero:
+                 q = q[1:]
+             else:
+                 q = q[:-1]
+         self.q = q
+
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         """generate a batch of q values
+
+         Args:
+             batch_size (int): the batch size
+
+         Returns:
+             Tensor: generated batch of q values
+         """
+         return self.q.clone()[None].expand(batch_size, self.q.shape[0])
+
+
+ class VariableQ(QGenerator):
+     """Q generator for reflectivity curves with variable discretization
+
+     Args:
+         q_min_range (list, optional): the range for sampling the minimum q value of the curves, *q_min*. Defaults to [0.01, 0.03].
+         q_max_range (list, optional): the range for sampling the maximum q value of the curves, *q_max*. Defaults to [0.1, 0.5].
+         n_q_range (list, optional): the range for the number of points in the curves (equidistantly sampled between *q_min*
+             and *q_max*; the number of points varies between batches but is constant within a batch). Defaults to [64, 256].
+         device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
+         dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
+     """
+
+     def __init__(self,
+                  q_min_range: Tuple[float, float] = (0.01, 0.03),
+                  q_max_range: Tuple[float, float] = (0.1, 0.5),
+                  n_q_range: Tuple[int, int] = (64, 256),
+                  device=DEFAULT_DEVICE,
+                  dtype=DEFAULT_DTYPE,
+                  ):
+         self.q_min_range = q_min_range
+         self.q_max_range = q_max_range
+         self.n_q_range = n_q_range
+         self.device = device
+         self.dtype = dtype
+
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         """generate a batch of q values (the number of points varies between batches but is constant within a batch)
+
+         Args:
+             batch_size (int): the batch size
+
+         Returns:
+             Tensor: generated batch of q values
+         """
+         q_min = np.random.uniform(*self.q_min_range, batch_size)
+         q_max = np.random.uniform(*self.q_max_range, batch_size)
+         if self.n_q_range[0] == self.n_q_range[1]:
+             n_q = self.n_q_range[0]
+         else:
+             n_q = np.random.randint(self.n_q_range[0], self.n_q_range[1] + 1)
+
+         q = torch.from_numpy(np.linspace(q_min, q_max, n_q).T).to(self.device).to(self.dtype)
+
+         return q
+
+     def scale_q(self, q):
+         """scales the q values to the range [-1, 1], relative to the smallest possible q_min and largest possible q_max
+
+         Args:
+             q (Tensor): unscaled q values
+
+         Returns:
+             Tensor: scaled q values
+         """
+         scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])
+
+         return 2.0 * (scaled_q_01 - 0.5)
+
+
+ class ConstantAngle(QGenerator):
+     """Q generator for reflectivity curves measured at equidistant angles
+
+     Args:
+         angle_range (Tuple[float, float, int], optional): the range of the incident angles, as a tuple
+             (min_angle, max_angle, num_angles). Defaults to (0., 0.2, 257).
+         wavelength (float, optional): the beam wavelength in units of angstroms. Defaults to 1.
+         device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
+         dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
+     """
+     def __init__(self,
+                  angle_range: Tuple[float, float, int] = (0., 0.2, 257),
+                  wavelength: float = 1.,
+                  device=DEFAULT_DEVICE,
+                  dtype=DEFAULT_DTYPE,
+                  ):
+         self.q = torch.from_numpy(angle_to_q(np.linspace(*angle_range), wavelength)).to(device).to(dtype)
+
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         """generate a batch of q values
+
+         Args:
+             batch_size (int): the batch size
+
+         Returns:
+             Tensor: generated batch of q values
+         """
+         return self.q.clone()[None].expand(batch_size, self.q.shape[0])
+
+
+ class EquidistantQ(QGenerator):
+     """Q generator producing equidistant q grids with a randomly sampled maximum q value"""
+     def __init__(self,
+                  max_range: Tuple[float, float],
+                  num_values: Union[int, Tuple[int, int]],
+                  device=None,
+                  dtype=torch.float64
+                  ):
+         self.max_range = max_range
+         self._num_values = num_values
+         self.device = device
+         self.dtype = dtype
+
+     @property
+     def num_values(self) -> int:
+         if isinstance(self._num_values, int):
+             return self._num_values
+         return np.random.randint(*self._num_values)
+
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         num_values = self.num_values
+         q_max = uniform_sampler(*self.max_range, batch_size, 1, device=self.device, dtype=self.dtype)
+         norm_qs = torch.linspace(0, 1, num_values + 1, device=self.device, dtype=self.dtype)[1:][None]
+         qs = norm_qs * q_max
+         return qs
+
+
+ class TransformerQ(QGenerator):
+     """Q generator which adapts the q spacing to the total film thickness provided in the context"""
+     def __init__(self,
+                  q_max: float = 0.2,
+                  num_values: Union[int, Tuple[int, int]] = (30, 512),  # expected to be a (min, max) tuple here
+                  min_dq_ratio: float = 5.,
+                  device=None,
+                  dtype=torch.float64,
+                  ):
+         self.min_dq_ratio = min_dq_ratio
+         self.q_max = q_max
+         self._dq_range = q_max / num_values[1], q_max / num_values[0]
+         self._num_values = num_values
+         self.device = device
+         self.dtype = dtype
+
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         assert context is not None
+
+         params: BasicParams = context['params']
+         total_thickness = params.thicknesses.sum(-1)
+
+         assert total_thickness.shape[0] == batch_size
+
+         # finest q spacing considered: about min_dq_ratio points per Kiessig fringe
+         # (fringe period ~ 2 * pi / total thickness), clamped to the allowed dq range
+         min_dqs = torch.clamp(
+             2 * np.pi / total_thickness / self.min_dq_ratio, self._dq_range[0], self._dq_range[1] * 0.9
+         )
+
+         dqs = torch.rand_like(min_dqs) * (self._dq_range[1] - min_dqs) + min_dqs
+
+         num_q_values = torch.clamp(self.q_max // dqs, *self._num_values).to(torch.int)
+
+         q_values, mask = generate_q_padding_mask(num_q_values, self.q_max)
+
+         context['tgt_key_padding_mask'] = mask
+         context['num_q_values'] = num_q_values
+
+         return q_values
+
+
+ def generate_q_padding_mask(num_q_values: Tensor, q_max: float):
+     batch_size = num_q_values.shape[0]
+     dqs = (q_max / num_q_values)[:, None]
+     q_values = torch.arange(1, num_q_values.max().item() + 1, device=num_q_values.device)[None].repeat(batch_size, 1) * dqs
+     # positions beyond each curve's own grid are zeroed and flagged in the padding mask
+     mask = (q_values > q_max + dqs / 2)
+     q_values[mask] = 0.
+     return q_values, mask
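
To illustrate the two most commonly used generators above, here is a minimal usage sketch (not part of the diff; it passes an explicit CPU device so it does not depend on DEFAULT_DEVICE, and the ranges are just the defaults):

import torch
from reflectorch.data_generation.q_generator import ConstantQ, VariableQ

# fixed grid: every curve in the batch shares the same 128-point grid on [0, 0.2]
fixed_gen = ConstantQ(q=(0., 0.2, 128), device='cpu', dtype=torch.float64)
q_fixed = fixed_gen.get_batch(batch_size=4)      # shape (4, 128)

# variable grid: q_min, q_max and the number of points are drawn anew for each batch
var_gen = VariableQ(q_min_range=(0.01, 0.03), q_max_range=(0.1, 0.5),
                    n_q_range=(64, 256), device='cpu', dtype=torch.float64)
q_var = var_gen.get_batch(batch_size=4)          # shape (4, n_q) with 64 <= n_q <= 256
q_scaled = var_gen.scale_q(q_var)                # mapped into [-1, 1] relative to the widest possible grid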
reflectorch/data_generation/reflectivity/__init__.py
@@ -0,0 +1,56 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.reflectivity.abeles import abeles_compiled, abeles
+ from reflectorch.data_generation.reflectivity.memory_eff import abeles_memory_eff
+ from reflectorch.data_generation.reflectivity.numpy_implementations import (
+     kinematical_approximation_np,
+     abeles_np,
+ )
+ from reflectorch.data_generation.reflectivity.smearing import abeles_constant_smearing
+ from reflectorch.data_generation.reflectivity.kinematical import kinematical_approximation
+
+
+ def reflectivity(
+         q: Tensor,
+         thickness: Tensor,
+         roughness: Tensor,
+         sld: Tensor,
+         dq: Tensor = None,
+         gauss_num: int = 51,
+         constant_dq: bool = True,
+         log: bool = False,
+         abeles_func=None,
+ ):
+     """Computes reflectivity curves from thin film parameters using the Abeles matrix formalism
+
+     Args:
+         q (Tensor): the momentum transfer (q) values
+         thickness (Tensor): the layer thicknesses
+         roughness (Tensor): the interlayer roughnesses
+         sld (Tensor): the SLDs of the layers
+         dq (Tensor, optional): the resolution for curve smearing; no smearing is applied if None. Defaults to None.
+         gauss_num (int, optional): the number of Gaussians used for curve smearing. Defaults to 51.
+         constant_dq (bool, optional): whether the smearing resolution is constant across the curve. Defaults to True.
+         log (bool, optional): if True, the base-10 logarithm of the reflectivity curves is returned. Defaults to False.
+         abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if
+             different from the default implementation. Defaults to None.
+
+     Returns:
+         Tensor: the computed reflectivity curves
+     """
+     abeles_func = abeles_func or abeles
+     q = torch.atleast_2d(q)
+
+     if dq is None:
+         reflectivity_curves = abeles_func(q, thickness, roughness, sld)
+     else:
+         reflectivity_curves = abeles_constant_smearing(
+             q, thickness, roughness, sld,
+             dq=dq, gauss_num=gauss_num, constant_dq=constant_dq, abeles_func=abeles_func
+         )
+
+     if log:
+         reflectivity_curves = torch.log10(reflectivity_curves)
+     return reflectivity_curves
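
A usage sketch (not part of the package; the layer values are arbitrary, illustrative numbers): for N layers the function expects N thicknesses, N + 1 roughnesses and N + 1 SLDs per sample, with the trailing entries describing the substrate, and SLDs in units of 10^-6 per square angstrom as implied by the 1e-6 scaling inside abeles:

import torch
from reflectorch.data_generation.reflectivity import reflectivity

q = torch.linspace(0.01, 0.2, 128, dtype=torch.float64)

# a batch of one two-layer film on a substrate
thickness = torch.tensor([[80., 40.]], dtype=torch.float64)    # N = 2 layer thicknesses (angstrom)
roughness = torch.tensor([[3., 5., 2.]], dtype=torch.float64)  # N + 1 interface roughnesses
sld = torch.tensor([[3.5, 4.7, 2.07]], dtype=torch.float64)    # N + 1 SLDs (layers + substrate)

curves = reflectivity(q, thickness, roughness, sld)            # shape (1, 128)
log_curves = reflectivity(q, thickness, roughness, sld, log=True)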
reflectorch/data_generation/reflectivity/abeles.py
@@ -0,0 +1,81 @@
+ # -*- coding: utf-8 -*-
+ import math
+
+ import torch
+ from torch import Tensor
+
+
+ def abeles(
+         q: Tensor,
+         thickness: Tensor,
+         roughness: Tensor,
+         sld: Tensor,
+ ):
+     """Simulates reflectivity curves for SLD profiles with box model parameterization using the Abeles matrix method
+
+     Args:
+         q (Tensor): q values
+         thickness (Tensor): layer thicknesses
+         roughness (Tensor): interlayer roughnesses
+         sld (Tensor): layer SLDs
+
+     Returns:
+         Tensor: simulated reflectivity curves
+     """
+     c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
+
+     batch_size, num_layers = thickness.shape
+
+     # prepend the ambient medium (zero SLD, zero thickness); roughness is squared for the Nevot-Croce factor
+     sld = torch.cat([torch.zeros(batch_size, 1).to(sld), sld], -1)[:, None]
+     thickness = torch.cat([torch.zeros(batch_size, 1).to(thickness), thickness], -1)[:, None]
+     roughness = roughness[:, None] ** 2
+
+     sld = sld * 1e-6 + 1e-30j
+
+     k_z0 = (q / 2).to(c_dtype)
+
+     if k_z0.dim() == 1:
+         k_z0.unsqueeze_(0)
+
+     if k_z0.dim() == 2:
+         k_z0.unsqueeze_(-1)
+
+     k_n = torch.sqrt(k_z0 ** 2 - 4 * math.pi * sld)
+
+     # k_n.shape - (batch, q, layers)
+
+     k_n, k_np1 = k_n[..., :-1], k_n[..., 1:]
+
+     beta = 1j * thickness * k_n
+
+     exp_beta = torch.exp(beta)
+     exp_m_beta = torch.exp(-beta)
+
+     # Fresnel coefficients at each interface with Nevot-Croce roughness damping
+     rn = (k_n - k_np1) / (k_n + k_np1) * torch.exp(- 2 * k_n * k_np1 * roughness)
+
+     c_matrices = torch.stack([
+         torch.stack([exp_beta, rn * exp_m_beta], -1),
+         torch.stack([rn * exp_beta, exp_m_beta], -1),
+     ], -1)
+
+     c_matrices = [c.squeeze(-3) for c in c_matrices.split(1, -3)]
+
+     m, c_matrices = c_matrices[0], c_matrices[1:]
+
+     for c in c_matrices:
+         m = m @ c
+
+     r = (m[..., 1, 0] / m[..., 0, 0]).abs() ** 2
+     r = torch.clamp_max_(r, 1.)
+
+     return r
+
+
+ # @torch.jit.script  # commented out so far due to an issue with complex numbers in TorchScript
+ def abeles_compiled(
+         q: Tensor,
+         thickness: Tensor,
+         roughness: Tensor,
+         sld: Tensor,
+ ):
+     return abeles(q, thickness, roughness, sld)
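
The simulator is fully batched: a shared 1D q grid is broadcast over the batch, and a per-curve 2D grid of shape (batch, num_q) is also accepted. A minimal sketch with random parameters (not part of the package; the parameter ranges are arbitrary):

import torch
from reflectorch.data_generation.reflectivity.abeles import abeles

batch_size, num_layers, num_q = 8, 2, 128

q = torch.linspace(0.02, 0.3, num_q, dtype=torch.float64)          # shared grid, shape (num_q,)
thickness = 50. + 100. * torch.rand(batch_size, num_layers, dtype=torch.float64)
roughness = 5. * torch.rand(batch_size, num_layers + 1, dtype=torch.float64)
sld = 5. * torch.rand(batch_size, num_layers + 1, dtype=torch.float64)

r = abeles(q, thickness, roughness, sld)                           # shape (batch_size, num_q)
assert r.shape == (batch_size, num_q) and (r <= 1.).all()          # reflectivity is clamped to <= 1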
reflectorch/data_generation/reflectivity/kinematical.py
@@ -0,0 +1,58 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+ from torch import Tensor
+
+
+ def kinematical_approximation(
+         q: Tensor,
+         thickness: Tensor,
+         roughness: Tensor,
+         sld: Tensor,
+         *,
+         apply_fresnel: bool = True,
+         log: bool = False,
+ ):
+     """Simulates reflectivity curves in the kinematical (single-scattering) approximation,
+     optionally scaled by the Fresnel reflectivity of the substrate"""
+     c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
+
+     batch_size, num_layers = thickness.shape
+
+     q = q.to(c_dtype)
+
+     if q.dim() == 1:
+         q.unsqueeze_(0)
+
+     if q.dim() == 2:
+         q.unsqueeze_(-1)
+
+     sld = sld * 1e-6 + 1e-30j
+
+     # SLD contrasts at the interfaces and cumulative depths of the interfaces
+     drho = torch.cat([sld[..., 0][..., None], sld[..., 1:] - sld[..., :-1]], -1)[:, None]
+     thickness = torch.cumsum(torch.cat([torch.zeros(batch_size, 1).to(thickness), thickness], -1), -1)[:, None]
+     roughness = roughness[:, None]
+
+     r = (drho * torch.exp(- (roughness * q) ** 2 / 2 + 1j * (q * thickness))).sum(-1).abs().float() ** 2
+
+     if apply_fresnel:
+
+         substrate_sld = sld[:, -1:]
+
+         rf = _get_fresnel_reflectivity(q, substrate_sld[:, None])
+
+         r = torch.clamp_max_(r * rf / substrate_sld.real ** 2, 1.)
+
+     if log:
+         r = torch.log10(r)
+
+     return r
+
+
+ def _get_fresnel_reflectivity(q, substrate_slds):
+     _RE_CONST = 0.28174103675406496
+
+     # critical wavevector q_c of the substrate, then the Fresnel reflectivity of an ideal interface
+     q_c = torch.sqrt(substrate_slds + 0j) / _RE_CONST * 2
+     q_prime = torch.sqrt(q ** 2 - q_c ** 2 + 0j)
+     r_f = ((q - q_prime) / (q + q_prime)).abs().float() ** 2
+
+     return r_f.squeeze(-1)
+
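
As a rough cross-check (an illustrative sketch, not part of the package; the layer values are arbitrary): sufficiently far above the critical edge, the Fresnel-scaled kinematical result is expected to approach the dynamical (Abeles) one, while near total external reflection it deviates strongly:

import torch
from reflectorch.data_generation.reflectivity.abeles import abeles
from reflectorch.data_generation.reflectivity.kinematical import kinematical_approximation

q = torch.linspace(0.05, 0.3, 128, dtype=torch.float64)   # well above the critical edge
thickness = torch.tensor([[120.]], dtype=torch.float64)
roughness = torch.tensor([[3., 3.]], dtype=torch.float64)
sld = torch.tensor([[3.5, 2.07]], dtype=torch.float64)

r_dyn = abeles(q, thickness, roughness, sld)
r_kin = kinematical_approximation(q, thickness, roughness, sld)

# relative deviation of the single-scattering approximation from the full dynamical result
print(((r_kin - r_dyn).abs() / r_dyn).max())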
reflectorch/data_generation/reflectivity/memory_eff.py
@@ -0,0 +1,92 @@
+ # -*- coding: utf-8 -*-
+
+ from math import pi
+
+ import torch
+ from torch import Tensor
+
+
+ def abeles_memory_eff(
+         q: Tensor,
+         thickness: Tensor,
+         roughness: Tensor,
+         sld: Tensor,
+ ):
+     """Memory-efficient variant of the Abeles matrix method which accumulates the transfer matrix
+     interface by interface instead of materializing all layer matrices at once"""
+     c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
+
+     batch_size, num_layers = thickness.shape
+
+     sld = sld * 1e-6 + 1e-30j
+
+     num_interfaces = num_layers + 1
+
+     k_z0 = (q / 2).to(c_dtype)
+
+     if len(k_z0.shape) == 1:
+         k_z0.unsqueeze_(0)
+
+     thickness_prev_layer = 1.  # ambient
+
+     for interface_num in range(num_interfaces):
+
+         prev_layer_idx = interface_num - 1
+         next_layer_idx = interface_num
+
+         if interface_num == 0:
+             k_z_previous_layer = _get_relative_k_z(k_z0, torch.zeros(batch_size, 1).to(sld))
+         else:
+             thickness_prev_layer = thickness[:, prev_layer_idx].unsqueeze(1)
+             k_z_previous_layer = _get_relative_k_z(k_z0, sld[:, prev_layer_idx].unsqueeze(1))
+
+         k_z_next_layer = _get_relative_k_z(k_z0, sld[:, next_layer_idx].unsqueeze(1))  # (batch_num, q_num)
+
+         reflection_matrix = _make_reflection_matrix(
+             k_z_previous_layer, k_z_next_layer, roughness[:, interface_num].unsqueeze(1)
+         )
+
+         if interface_num == 0:
+             total_reflectivity_matrix = reflection_matrix
+         else:
+             translation_matrix = _make_translation_matrix(k_z_previous_layer, thickness_prev_layer)
+
+             total_reflectivity_matrix = torch.einsum(
+                 'bnmr, bmlr, bljr -> bnjr', total_reflectivity_matrix, translation_matrix, reflection_matrix
+             )
+
+     r = total_reflectivity_matrix[:, 0, 1] / total_reflectivity_matrix[:, 1, 1]
+
+     reflectivity = torch.clamp_max_(torch.abs(r) ** 2, 1.).flatten(1)
+
+     return reflectivity
+
+
+ def _get_relative_k_z(k_z0, scattering_length_density):
+     return torch.sqrt(k_z0 ** 2 - 4 * pi * scattering_length_density)
+
+
+ def _make_reflection_matrix(k_z_previous_layer, k_z_next_layer, interface_roughness):
+     p = _safe_div((k_z_previous_layer + k_z_next_layer), (2 * k_z_previous_layer)) * \
+         torch.exp(-(k_z_previous_layer - k_z_next_layer) ** 2 * 0.5 * interface_roughness ** 2)
+
+     m = _safe_div((k_z_previous_layer - k_z_next_layer), (2 * k_z_previous_layer)) * \
+         torch.exp(-(k_z_previous_layer + k_z_next_layer) ** 2 * 0.5 * interface_roughness ** 2)
+
+     return _stack_mtx(p, m, m, p)
+
+
+ def _stack_mtx(a11, a12, a21, a22):
+     return torch.stack([
+         torch.stack([a11, a12], dim=1),
+         torch.stack([a21, a22], dim=1),
+     ], dim=1)
+
+
+ def _make_translation_matrix(k_z, thickness):
+     return _stack_mtx(
+         torch.exp(-1j * k_z * thickness), torch.zeros_like(k_z),
+         torch.zeros_like(k_z), torch.exp(1j * k_z * thickness)
+     )
+
+
+ def _safe_div(numerator, denominator):
+     # avoids division by zero (k_z = 0 at grazing incidence) by passing the numerator through unchanged
+     return torch.where(denominator == 0, numerator, torch.divide(numerator, denominator))
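
A consistency sketch between the two torch implementations (not part of the package; parameter ranges are arbitrary). For ideal (zero-roughness) interfaces the two should coincide to numerical precision; with roughness they apply slightly different damping conventions (the Nevot-Croce factor on the reflection coefficient alone versus damping of both matrix elements), so only approximate agreement is expected:

import torch
from reflectorch.data_generation.reflectivity.abeles import abeles
from reflectorch.data_generation.reflectivity.memory_eff import abeles_memory_eff

torch.manual_seed(0)
batch_size, num_layers = 16, 4
q = torch.linspace(0.01, 0.3, 256, dtype=torch.float64)
thickness = 20. + 200. * torch.rand(batch_size, num_layers, dtype=torch.float64)
roughness = torch.zeros(batch_size, num_layers + 1, dtype=torch.float64)  # ideal interfaces
sld = 6. * torch.rand(batch_size, num_layers + 1, dtype=torch.float64)

r_full = abeles(q, thickness, roughness, sld)
r_mem = abeles_memory_eff(q, thickness, roughness, sld)
print((r_full - r_mem).abs().max())   # expected to be at numerical-precision level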
reflectorch/data_generation/reflectivity/numpy_implementations.py
@@ -0,0 +1,120 @@
+ # -*- coding: utf-8 -*-
+
+ import numpy as np
+
+
+ def abeles_np(
+         q: np.ndarray,
+         thickness: np.ndarray,
+         roughness: np.ndarray,
+         sld: np.ndarray,
+ ):
+     """NumPy implementation of the Abeles matrix method; accepts both batched 2d and unbatched 1d parameter arrays"""
+     # note: dtypes must be compared with ==, since `q.dtype is np.float64` is always False
+     c_dtype = np.complex128 if q.dtype == np.float64 else np.complex64
+
+     if q.ndim == thickness.ndim == roughness.ndim == sld.ndim == 1:
+         zero_batch = True
+     else:
+         zero_batch = False
+
+     thickness = np.atleast_2d(thickness)
+     roughness = np.atleast_2d(roughness)
+     sld = np.atleast_2d(sld)
+
+     batch_size, num_layers = thickness.shape
+
+     sld = np.concatenate([np.zeros((batch_size, 1)).astype(sld.dtype), sld], -1)[:, None]
+     thickness = np.concatenate([np.zeros((batch_size, 1)).astype(thickness.dtype), thickness], -1)[:, None]
+     roughness = roughness[:, None] ** 2
+
+     sld = sld * 1e-6 + 1e-30j
+
+     k_z0 = (q / 2).astype(c_dtype)
+
+     if len(k_z0.shape) == 1:
+         k_z0 = k_z0[None]
+
+     if len(k_z0.shape) == 2:
+         k_z0 = k_z0[..., None]
+
+     k_n = np.sqrt(k_z0 ** 2 - 4 * np.pi * sld)
+
+     # k_n.shape - (batch, q, layers)
+
+     k_n, k_np1 = k_n[..., :-1], k_n[..., 1:]
+
+     beta = 1j * thickness * k_n
+
+     exp_beta = np.exp(beta)
+     exp_m_beta = np.exp(-beta)
+
+     rn = (k_n - k_np1) / (k_n + k_np1) * np.exp(- 2 * k_n * k_np1 * roughness)
+
+     c_matrices = np.stack([
+         np.stack([exp_beta, rn * exp_m_beta], -1),
+         np.stack([rn * exp_beta, exp_m_beta], -1),
+     ], -1)
+
+     c_matrices = np.moveaxis(c_matrices, -3, 0)
+
+     m, c_matrices = c_matrices[0], c_matrices[1:]
+
+     for c in c_matrices:
+         m = m @ c
+
+     r = np.abs(m[..., 1, 0] / m[..., 0, 0]) ** 2
+     r = np.clip(r, None, 1.)
+
+     if zero_batch:
+         r = r[0]
+
+     return r
+
+
+
73
+ def kinematical_approximation_np(
74
+ q: np.ndarray,
75
+ thickness: np.ndarray,
76
+ roughness: np.ndarray,
77
+ sld: np.ndarray,
78
+ ):
79
+ if q.ndim == thickness.ndim == roughness.ndim == sld.ndim == 1:
80
+ zero_batch = True
81
+ else:
82
+ zero_batch = False
83
+
84
+ thickness = np.atleast_2d(thickness)
85
+ roughness = np.atleast_2d(roughness)
86
+ sld = np.atleast_2d(sld) * 1e-6 + 1e-30j
87
+ substrate_sld = sld[:, -1:]
88
+
89
+ batch_size, num_layers = thickness.shape
90
+
91
+ if q.ndim == 1:
92
+ q = q[None]
93
+
94
+ if q.ndim == 2:
95
+ q = q[..., None]
96
+
97
+ drho = np.concatenate([sld[..., 0][..., None], sld[..., 1:] - sld[..., :-1]], -1)[:, None]
98
+ thickness = np.cumsum(np.concatenate([np.zeros((batch_size, 1)), thickness], -1), -1)[:, None]
99
+ roughness = roughness[:, None]
100
+
101
+ r = np.abs((drho * np.exp(- (roughness * q) ** 2 / 2 + 1j * (q * thickness))).sum(-1)).astype(float) ** 2
102
+
103
+ rf = _get_resnel_reflectivity_np(q, substrate_sld[:, None])
104
+
105
+ r = np.clip(r * rf / np.real(substrate_sld) ** 2, None, 1.)
106
+
107
+ if zero_batch:
108
+ r = r[0]
109
+
110
+ return r
111
+
112
+
113
+ def _get_resnel_reflectivity_np(q, substrate_slds):
114
+ _RE_CONST = 0.28174103675406496
115
+
116
+ q_c = np.sqrt(substrate_slds + 0j) / _RE_CONST * 2
117
+ q_prime = np.sqrt(q ** 2 - q_c ** 2 + 0j)
118
+ r_f = np.abs((q - q_prime) / (q + q_prime)).astype(float) ** 2
119
+
120
+ return r_f[..., 0]
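
Since abeles_np mirrors the torch implementation line for line, including the Nevot-Croce roughness convention, the two should agree to numerical precision in float64. A cross-check sketch (not part of the package; parameter ranges are arbitrary):

import numpy as np
import torch
from reflectorch.data_generation.reflectivity.abeles import abeles
from reflectorch.data_generation.reflectivity.numpy_implementations import abeles_np

rng = np.random.default_rng(0)
q = np.linspace(0.01, 0.3, 256)
thickness = 20. + 200. * rng.random((8, 3))
roughness = 5. * rng.random((8, 4))
sld = 6. * rng.random((8, 4))

r_np = abeles_np(q, thickness, roughness, sld)
r_torch = abeles(torch.from_numpy(q), torch.from_numpy(thickness),
                 torch.from_numpy(roughness), torch.from_numpy(sld)).numpy()
print(np.abs(r_np - r_torch).max())   # expected to be at numerical-precision level

# unbatched 1d inputs are also supported and return a 1d curve
r_single = abeles_np(q, thickness[0], roughness[0], sld[0])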