reflectorch 1.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
reflectorch/data_generation/q_generator.py
@@ -0,0 +1,280 @@
+ from typing import Tuple, Union
+
+ import numpy as np
+
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.utils import uniform_sampler
+ from reflectorch.data_generation.priors import BasicParams
+ from reflectorch.utils import angle_to_q
+ from reflectorch.data_generation.priors.no_constraints import DEFAULT_DEVICE, DEFAULT_DTYPE
+
+ __all__ = [
+     "QGenerator",
+     "ConstantQ",
+     "VariableQ",
+     "EquidistantQ",
+     "ConstantAngle",
+     "MaskedVariableQ",
+ ]
+
+
+ class QGenerator(object):
+     """Base class for momentum transfer (q) generators"""
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         pass
+
+
+ class ConstantQ(QGenerator):
+     """Q generator for reflectivity curves with a fixed discretization
+
+     Args:
+         q (Union[Tensor, Tuple[float, float, int]], optional): tuple (q_min, q_max, num_q) defining the minimum q value, the maximum q value and the number of q points. Defaults to (0., 0.2, 128).
+         device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
+         dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
+         remove_zero (bool, optional): if True, one endpoint of the grid is dropped: the last point by default, or the first point when ``fixed_zero`` is True. Defaults to False.
+         fixed_zero (bool, optional): together with ``remove_zero``, drops the first point of the grid instead of the last. Defaults to False.
+     """
+
+     def __init__(self,
+                  q: Union[Tensor, Tuple[float, float, int]] = (0., 0.2, 128),
+                  device=DEFAULT_DEVICE,
+                  dtype=DEFAULT_DTYPE,
+                  remove_zero: bool = False,
+                  fixed_zero: bool = False,
+                  ):
+         if isinstance(q, (tuple, list)):
+             q = torch.linspace(*q, device=device, dtype=dtype)
+         if remove_zero:
+             if fixed_zero:
+                 q = q[1:]
+             else:
+                 q = q[:-1]
+         self.q_min = q.min().item()
+         self.q_max = q.max().item()
+         self.q = q
+
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         """generates a batch of q values
+
+         Args:
+             batch_size (int): the batch size
+
+         Returns:
+             Tensor: generated batch of q values
+         """
+         return self.q.clone()[None].expand(batch_size, self.q.shape[0])
+
+     def scale_q(self, q):
+         """scales the q values to the range [-1, 1]
+
+         Args:
+             q (Tensor): unscaled q values
+
+         Returns:
+             Tensor: scaled q values
+         """
+         scaled_q_01 = (q - self.q_min) / (self.q_max - self.q_min)
+         return 2.0 * (scaled_q_01 - 0.5)
+
+
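A minimal usage sketch (editor's illustration, not part of the package source; assumes DEFAULT_DEVICE resolves to an available device):

    from reflectorch.data_generation.q_generator import ConstantQ

    q_gen = ConstantQ(q=(0., 0.2, 128))
    q_batch = q_gen.get_batch(batch_size=4)   # shape [4, 128], the same grid for every curve
    q_scaled = q_gen.scale_q(q_batch)         # grid mapped linearly onto [-1, 1]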
+ class VariableQ(QGenerator):
+     """Q generator for reflectivity curves with a variable discretization
+
+     Args:
+         q_min_range (list, optional): the range for sampling the minimum q value of the curves, q_min. Defaults to [0.01, 0.03].
+         q_max_range (list, optional): the range for sampling the maximum q value of the curves, q_max. Defaults to [0.1, 0.5].
+         n_q_range (list, optional): the range for the number of points in the curves (the number of points varies between batches but is constant within a batch). Defaults to [64, 256].
+         mode (str, optional): how the points are distributed between q_min and q_max, one of 'equidistant', 'random' or 'logspace'. Defaults to 'equidistant'.
+         device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
+         dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
+     """
+
+     def __init__(self,
+                  q_min_range: Tuple[float, float] = (0.01, 0.03),
+                  q_max_range: Tuple[float, float] = (0.1, 0.5),
+                  n_q_range: Tuple[int, int] = (64, 256),
+                  mode: str = 'equidistant',
+                  device=DEFAULT_DEVICE,
+                  dtype=DEFAULT_DTYPE,
+                  ):
+         self.q_min_range = q_min_range
+         self.q_max_range = q_max_range
+         self.n_q_range = n_q_range
+         self.mode = mode
+         self.device = device
+         self.dtype = dtype
+
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         """generates a batch of q values (the number of points varies between batches but is constant within a batch)
+
+         Args:
+             batch_size (int): the batch size
+
+         Returns:
+             Tensor: generated batch of q values
+         """
+
+         q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
+         q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
+
+         n_q = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (1,), device=self.device).item()
+
+         if self.mode == 'equidistant':
+             q = torch.linspace(0, 1, n_q, device=self.device, dtype=self.dtype)
+         elif self.mode == 'random':
+             q = torch.rand(n_q, device=self.device, dtype=self.dtype).sort().values
+         elif self.mode == 'logspace':
+             q = torch.logspace(
+                 start=torch.log10(torch.tensor(1e-4, dtype=self.dtype, device=self.device)),
+                 end=torch.log10(torch.tensor(1.0, dtype=self.dtype, device=self.device)),
+                 steps=n_q, dtype=self.dtype, device=self.device)
+         else:
+             raise ValueError(f"Unknown spacing mode: {self.mode}")
+
+         q = q_min[:, None] + q * (q_max - q_min)[:, None]
+
+         return q
+
+     def scale_q(self, q):
+         """scales the q values to the range [-1, 1]
+
+         Args:
+             q (Tensor): unscaled q values
+
+         Returns:
+             Tensor: scaled q values
+         """
+         scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])
+
+         return 2.0 * (scaled_q_01 - 0.5)
+
+
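A usage sketch (editor's illustration): each call draws a fresh q range and point count, so consecutive batches differ in both extent and length:

    q_gen = VariableQ(q_min_range=(0.01, 0.03), q_max_range=(0.1, 0.5), n_q_range=(64, 256))
    q1 = q_gen.get_batch(batch_size=8)  # e.g. shape [8, 137]; each row has its own q_min and q_max
    q2 = q_gen.get_batch(batch_size=8)  # the point count is re-sampled, e.g. shape [8, 201]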
+ class ConstantAngle(QGenerator):
+     """Q generator for reflectivity curves measured at equidistant angles
+
+     Args:
+         angle_range (Tuple[float, float, int], optional): the range of the incident angles. Defaults to (0., 0.2, 257).
+         wavelength (float, optional): the beam wavelength in units of angstroms. Defaults to 1.
+         device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
+         dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
+     """
+     def __init__(self,
+                  angle_range: Tuple[float, float, int] = (0., 0.2, 257),
+                  wavelength: float = 1.,
+                  device=DEFAULT_DEVICE,
+                  dtype=DEFAULT_DTYPE,
+                  ):
+         self.q = torch.from_numpy(angle_to_q(np.linspace(*angle_range), wavelength)).to(device).to(dtype)
+
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         """generates a batch of q values
+
+         Args:
+             batch_size (int): the batch size
+
+         Returns:
+             Tensor: generated batch of q values
+         """
+         return self.q.clone()[None].expand(batch_size, self.q.shape[0])
+
+
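For reference (editor's note), angle_to_q converts incident angles to momentum transfer; in the standard specular geometry this is q = 4π·sin(θ)/λ, so the resulting grid is equidistant in angle but not in q. A sketch of the assumed relation (the degree convention is an assumption; the exact conversion lives in reflectorch.utils):

    import numpy as np
    theta = np.linspace(0., 0.2, 257)                  # incident angles, assumed in degrees
    q = 4 * np.pi * np.sin(np.radians(theta)) / 1.0    # wavelength = 1 angstrom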
+ class EquidistantQ(QGenerator):
+     """Q generator producing equidistant grids spanning (0, q_max], with q_max sampled
+     per batch from ``max_range`` and a fixed or per-batch-sampled number of points"""
+     def __init__(self,
+                  max_range: Tuple[float, float],
+                  num_values: Union[int, Tuple[int, int]],
+                  device=None,
+                  dtype=torch.float64
+                  ):
+         self.max_range = max_range
+         self._num_values = num_values
+         self.device = device
+         self.dtype = dtype
+
+     @property
+     def num_values(self) -> int:
+         if isinstance(self._num_values, int):
+             return self._num_values
+         return np.random.randint(*self._num_values)
+
+     def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+         num_values = self.num_values
+         q_max = uniform_sampler(*self.max_range, batch_size, 1, device=self.device, dtype=self.dtype)
+         norm_qs = torch.linspace(0, 1, num_values + 1, device=self.device, dtype=self.dtype)[1:][None]
+         qs = norm_qs * q_max
+         return qs
+
+
+ class MaskedVariableQ(QGenerator):
+     """Q generator with variable discretization and a padding mask: every curve in a batch has a
+     fixed tensor width of n_q_range[1] points, and a boolean mask written to the context dict marks
+     the valid points of each curve, optionally enforcing enough points to resolve the Kiessig
+     fringes of the thickest structure in the batch."""
+     def __init__(self,
+                  q_min_range=(0.01, 0.03),
+                  q_max_range=(0.1, 0.5),
+                  n_q_range=(64, 256),
+                  mode='equidistant',
+                  shuffle_mask=False,
+                  total_thickness_constraint=True,
+                  min_points_per_fringe=4,
+                  device=DEFAULT_DEVICE,
+                  dtype=DEFAULT_DTYPE):
+         self.q_min_range = q_min_range
+         self.q_max_range = q_max_range
+         self.n_q_range = n_q_range
+         self.device = device
+         self.dtype = dtype
+         self.mode = mode
+         self.shuffle_mask = shuffle_mask
+         self.total_thickness_constraint = total_thickness_constraint
+         self.min_points_per_fringe = min_points_per_fringe
+
+     def get_batch(self, batch_size, context):
+         assert context is not None
+
+         q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
+         q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
+
+         max_n_q = self.n_q_range[1]
+
+         if self.mode == 'equidistant':
+             positions = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(batch_size, max_n_q)
+         elif self.mode == 'random':
+             positions = torch.rand(batch_size, max_n_q, device=self.device, dtype=self.dtype)
+             positions, _ = positions.sort(dim=-1)
+         elif self.mode == 'mixed':
+             positions = torch.empty(batch_size, max_n_q, device=self.device, dtype=self.dtype)
+
+             half = batch_size // 2  # the first half of the batch gets equidistant positions
+             eq_pos = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(half, max_n_q)
+             positions[:half] = eq_pos
+
+             rand_pos = torch.rand(batch_size - half, max_n_q, device=self.device, dtype=self.dtype)  # the other half gets sorted random positions
+             rand_pos, _ = rand_pos.sort(dim=-1)
+             positions[half:] = rand_pos
+         else:
+             raise ValueError(f"Unknown spacing mode: {self.mode}")
+
+         q = q_min[:, None] + positions * (q_max - q_min)[:, None]
+
+         n_qs = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (batch_size,), device=self.device)
+
+         if 'params' in context and self.total_thickness_constraint:  # N_points > 1 + (q_spread * total_thickness * min_points_per_kiessig_fringe) / (2*pi)
+             d_total = context['params'].thicknesses.sum(-1)
+             limit = 1 + ((q_max - q_min) * d_total * self.min_points_per_fringe) / (2 * np.pi)
+             limit = limit.ceil().int()
+             n_qs = torch.maximum(n_qs, limit)
+             n_qs = torch.clamp(n_qs, max=self.n_q_range[1])
+
+         indices = torch.arange(max_n_q, device=self.device).expand(batch_size, max_n_q)
+         valid_mask = indices < n_qs[:, None]  # right-side padding
+
+         if self.shuffle_mask:  # shuffle the valid positions (interspersed padding)
+             perm = torch.argsort(torch.rand(batch_size, max_n_q, device=self.device), dim=-1)
+             valid_mask = torch.gather(valid_mask, dim=1, index=perm)
+
+         context['key_padding_mask'] = valid_mask
+         context['n_points'] = valid_mask.sum(dim=-1)
+
+         return q
+
+     def scale_q(self, q):
+         scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])
+
+         return 2.0 * (scaled_q_01 - 0.5)
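The thickness constraint above encodes a Kiessig-fringe sampling bound: a structure of total thickness d produces fringes with period 2π/d in q, so a window of width (q_max − q_min) contains (q_max − q_min)·d/(2π) fringes, and requiring min_points_per_fringe points per fringe (plus one endpoint) yields the limit in the code. A quick numeric check with illustrative values:

    import numpy as np
    d_total, q_min, q_max, min_pts = 300., 0.02, 0.3, 4    # angstroms, inverse angstroms
    n_fringes = (q_max - q_min) * d_total / (2 * np.pi)    # ≈ 13.4 fringes in the window
    n_min = int(np.ceil(1 + n_fringes * min_pts))          # ≈ 55 points required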
reflectorch/data_generation/reflectivity/__init__.py
@@ -0,0 +1,102 @@
+ # -*- coding: utf-8 -*-
+
+ from typing import Callable, Union
+
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.reflectivity.abeles import abeles_compiled, abeles
+ from reflectorch.data_generation.reflectivity.memory_eff import abeles_memory_eff
+ from reflectorch.data_generation.reflectivity.numpy_implementations import (
+     kinematical_approximation_np,
+     abeles_np,
+ )
+ from reflectorch.data_generation.reflectivity.smearing import abeles_constant_smearing
+ from reflectorch.data_generation.reflectivity.smearing_pointwise import abeles_pointwise_smearing
+ from reflectorch.data_generation.reflectivity.kinematical import kinematical_approximation
+
+
+ def reflectivity(
+     q: Tensor,
+     thickness: Tensor,
+     roughness: Tensor,
+     sld: Tensor,
+     dq: Tensor = None,
+     gauss_num: int = 51,
+     constant_dq: bool = False,
+     log: bool = False,
+     q_shift: Union[Tensor, float] = 0.0,
+     r_scale: Union[Tensor, float] = 1.0,
+     background: Union[Tensor, float] = 0.0,
+     solvent_vf: Tensor = None,
+     solvent_mode: str = 'fronting',
+     abeles_func: Callable = None,
+     **abeles_kwargs
+ ):
+     """Computes reflectivity curves from thin film parameters.
+     By default it uses the fast implementation of the Abeles matrix formalism.
+
+     Args:
+         q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
+         thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
+         roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
+         sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom) with shape
+             [batch_size, n_layers + 1] (excluding the ambient SLD, which is assumed to be 0) or [batch_size, n_layers + 2] (including the ambient SLD; only for the default ``abeles_func='abeles'``)
+         dq (Tensor, optional): tensor of resolutions used for curve smearing, with shape [batch_size, 1]
+             (interpreted as dq if ``constant_dq`` is ``True``, or as dq/q if ``constant_dq`` is ``False``) or shape [batch_size, n_points] for pointwise smearing. Defaults to None.
+         gauss_num (int, optional): the number of gaussians for curve smearing. Defaults to 51.
+         constant_dq (bool, optional): if ``True`` the smearing is constant (the same dq at each point of the curve),
+             otherwise the smearing is linear (the same dq/q at each point of the curve). Defaults to False.
+         log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to False.
+         q_shift (float or Tensor, optional): misalignment in q.
+         r_scale (float or Tensor, optional): normalization factor (scales the reflectivity).
+         background (float or Tensor, optional): background intensity.
+         solvent_vf (Tensor, optional): per-layer solvent volume fractions with shape [batch_size, n_layers]; if provided,
+             each layer SLD is mixed with the solvent SLD as vf * sld_solvent + (1 - vf) * sld_layer. Defaults to None.
+         solvent_mode (str, optional): which medium acts as the solvent, either 'fronting' (the ambient SLD, first entry) or 'backing' (the substrate SLD, last entry). Defaults to 'fronting'.
+         abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different from the default Abeles matrix implementation ('abeles'). Defaults to None.
+         abeles_kwargs: additional arguments specific to the chosen ``abeles_func``.
+     Returns:
+         Tensor: the computed reflectivity curves
+     """
+     abeles_func = abeles_func or abeles
+     q = torch.atleast_2d(q) + q_shift
+     q = torch.clamp(q, min=0.0)
+
+     if solvent_vf is not None:
+         sld = sld.clone()  # avoid mutating the caller's tensor in place
+         num_layers = thickness.shape[-1]
+         if solvent_mode == 'fronting':
+             assert sld.shape[-1] == num_layers + 2
+             assert solvent_vf.shape[-1] == num_layers
+             solvent_sld = sld[..., [0]]
+             idx = slice(1, num_layers + 1)  # all layers between ambient and substrate
+             sld[..., idx] = solvent_vf * solvent_sld + (1.0 - solvent_vf) * sld[..., idx]
+         elif solvent_mode == 'backing':
+             solvent_sld = sld[..., [-1]]
+             idx = slice(1, num_layers + 1) if sld.shape[-1] == num_layers + 2 else slice(0, num_layers)
+             sld[..., idx] = solvent_vf * solvent_sld + (1.0 - solvent_vf) * sld[..., idx]
+         else:
+             raise NotImplementedError
+
+     if dq is None:
+         reflectivity_curves = abeles_func(q, thickness, roughness, sld, **abeles_kwargs)
+     else:
+         if dq.shape[-1] > 1:
+             reflectivity_curves = abeles_pointwise_smearing(
+                 q=q, dq=dq, thickness=thickness, roughness=roughness, sld=sld,
+                 abeles_func=abeles_func, gauss_num=gauss_num,
+                 **abeles_kwargs,
+             )
+         else:
+             reflectivity_curves = abeles_constant_smearing(
+                 q, thickness, roughness, sld,
+                 dq=dq, gauss_num=gauss_num, constant_dq=constant_dq, abeles_func=abeles_func,
+                 **abeles_kwargs,
+             )
+
+     if isinstance(r_scale, Tensor):
+         r_scale = r_scale.view(-1, *[1] * (reflectivity_curves.dim() - 1))
+     if isinstance(background, Tensor):
+         background = background.view(-1, *[1] * (reflectivity_curves.dim() - 1))
+
+     reflectivity_curves = reflectivity_curves * r_scale + background
+
+     if log:
+         reflectivity_curves = torch.log10(reflectivity_curves)
+
+     return reflectivity_curves
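A usage sketch for a two-layer film on a substrate (editor's illustration; shapes follow the docstring above, parameter values are arbitrary, and the import path follows the file layout in the list above):

    import torch
    from reflectorch.data_generation.reflectivity import reflectivity

    q = torch.linspace(0.01, 0.3, 128)           # [n_points]
    thickness = torch.tensor([[120., 40.]])      # [1, n_layers]
    roughness = torch.tensor([[3., 5., 2.]])     # [1, n_layers + 1]
    sld = torch.tensor([[3.5, 1.8, 2.07]])       # [1, n_layers + 1], ambient SLD of 0 assumed
    dq = torch.full((1, 1), 0.05)                # interpreted as dq/q since constant_dq defaults to False
    log_r = reflectivity(q, thickness, roughness, sld, dq=dq, log=True)  # shape [1, 128]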
reflectorch/data_generation/reflectivity/abeles.py
@@ -0,0 +1,97 @@
+ # -*- coding: utf-8 -*-
+ import math
+ from functools import reduce
+
+ import torch
+ from torch import Tensor
+
+
+ def abeles(
+     q: Tensor,
+     thickness: Tensor,
+     roughness: Tensor,
+     sld: Tensor,
+ ):
+     """Simulates reflectivity curves for SLD profiles with box model parameterization using the Abeles matrix method
+
+     Args:
+         q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
+         thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
+         roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
+         sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom). The tensor shape should be one of the following:
+             - [batch_size, n_layers + 1]: in this case, the ambient SLD is not included but assumed to be 0
+             - [batch_size, n_layers + 2]: this shape includes the ambient SLD as the first element in the tensor
+
+     Returns:
+         Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
+     """
+     c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
+
+     batch_size, num_layers = thickness.shape
+
+     if sld.shape[-1] == num_layers + 1:
+         # add zero ambient sld
+         sld = torch.cat([torch.zeros(batch_size, 1).to(sld), sld], -1)
+     if sld.shape[-1] != num_layers + 2:
+         raise ValueError(
+             "Number of SLD values does not equal num_layers + 2 (ambient + substrate)."
+         )
+
+     sld = sld[:, None]
+
+     # add zero thickness for ambient layer:
+     thickness = torch.cat([torch.zeros(batch_size, 1).to(thickness), thickness], -1)[
+         :, None
+     ]
+
+     roughness = roughness[:, None] ** 2
+
+     sld = (sld - sld[..., :1]) * 1e-6 + 1e-36j
+
+     k_z0 = (q / 2).to(c_dtype)
+
+     if k_z0.dim() == 1:
+         k_z0.unsqueeze_(0)
+
+     if k_z0.dim() == 2:
+         k_z0.unsqueeze_(-1)
+
+     k_n = torch.sqrt(k_z0**2 - 4 * math.pi * sld)
+
+     # k_n.shape - (batch, q, layers)
+
+     k_n, k_np1 = k_n[..., :-1], k_n[..., 1:]
+
+     beta = 1j * thickness * k_n
+
+     exp_beta = torch.exp(beta)
+     exp_m_beta = torch.exp(-beta)
+
+     # Fresnel coefficients with Nevot-Croce roughness damping
+     rn = (k_n - k_np1) / (k_n + k_np1) * torch.exp(-2 * k_n * k_np1 * roughness)
+
+     c_matrices = torch.stack(
+         [
+             torch.stack([exp_beta, rn * exp_m_beta], -1),
+             torch.stack([rn * exp_beta, exp_m_beta], -1),
+         ],
+         -1,
+     )
+
+     c_matrices = [c.squeeze(-3) for c in c_matrices.split(1, -3)]
+
+     # accumulate the characteristic matrix product over all interfaces
+     m = reduce(torch.matmul, c_matrices)
+
+     r = (m[..., 1, 0] / m[..., 0, 0]).abs() ** 2
+     r = torch.clamp_max_(r, 1.0)
+
+     return r
+
+
+ # @torch.jit.script  # commented out so far due to an issue with complex numbers
+ def abeles_compiled(
+     q: Tensor,
+     thickness: Tensor,
+     roughness: Tensor,
+     sld: Tensor,
+ ):
+     return abeles(q, thickness, roughness, sld)
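A quick sanity check (editor's sketch): for a bare substrate the curve should sit at total reflection below the critical edge and never exceed 1:

    import torch
    from reflectorch.data_generation.reflectivity.abeles import abeles

    q = torch.linspace(0.005, 0.2, 200)
    thickness = torch.zeros(1, 0)        # bare substrate, no layers
    roughness = torch.tensor([[0.]])     # single ambient/substrate interface
    sld = torch.tensor([[2.07]])         # Si neutron SLD in 1e-6 per square angstrom; ambient 0 assumed
    r = abeles(q, thickness, roughness, sld)
    assert torch.all(r <= 1.0) and r[0, 0] > 0.99   # q_c = sqrt(16*pi*sld) ≈ 0.0102 here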
reflectorch/data_generation/reflectivity/kinematical.py
@@ -0,0 +1,71 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+ from torch import Tensor
+
+
+ def kinematical_approximation(
+     q: Tensor,
+     thickness: Tensor,
+     roughness: Tensor,
+     sld: Tensor,
+     *,
+     apply_fresnel: bool = True,
+     log: bool = False,
+ ):
+     """Simulates reflectivity curves for SLD profiles with box model parameterization using the kinematical approximation
+
+     Args:
+         q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
+         thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
+         roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
+         sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
+             It includes the substrate but excludes the ambient medium, which is assumed to have an SLD of 0.
+         apply_fresnel (bool, optional): whether to use the Fresnel coefficient in the computation. Defaults to ``True``.
+         log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to ``False``.
+
+     Returns:
+         Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
+     """
+     c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
+
+     batch_size, num_layers = thickness.shape
+
+     q = q.to(c_dtype)
+
+     if q.dim() == 1:
+         q.unsqueeze_(0)
+
+     if q.dim() == 2:
+         q.unsqueeze_(-1)
+
+     sld = sld * 1e-6 + 1e-30j
+
+     # SLD contrasts at the interfaces and the cumulative interface depths
+     drho = torch.cat([sld[..., 0][..., None], sld[..., 1:] - sld[..., :-1]], -1)[:, None]
+     thickness = torch.cumsum(torch.cat([torch.zeros(batch_size, 1).to(thickness), thickness], -1), -1)[:, None]
+     roughness = roughness[:, None]
+
+     r = (drho * torch.exp(- (roughness * q) ** 2 / 2 + 1j * (q * thickness))).sum(-1).abs().float() ** 2
+
+     if apply_fresnel:
+
+         substrate_sld = sld[:, -1:]
+
+         rf = _get_fresnel_reflectivity(q, substrate_sld[:, None])
+
+         r = torch.clamp_max_(r * rf / substrate_sld.real ** 2, 1.)
+
+     if log:
+         r = torch.log10(r)
+
+     return r
+
+
+ def _get_fresnel_reflectivity(q, substrate_slds):
+     _RE_CONST = 0.28174103675406496  # ≈ 2/sqrt(16*pi)
+
+     q_c = torch.sqrt(substrate_slds + 0j) / _RE_CONST * 2
+     q_prime = torch.sqrt(q ** 2 - q_c ** 2 + 0j)
+     r_f = ((q - q_prime) / (q + q_prime)).abs().float() ** 2
+
+     return r_f.squeeze(-1)
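The kinematical (Born) approximation neglects multiple scattering, so it is only reliable well above the critical edge; a comparison sketch against the dynamical solution (editor's illustration, arbitrary parameters):

    import torch
    from reflectorch.data_generation.reflectivity.abeles import abeles
    from reflectorch.data_generation.reflectivity.kinematical import kinematical_approximation

    q = torch.linspace(0.05, 0.3, 100)        # well above q_c
    thickness = torch.tensor([[100.]])
    roughness = torch.tensor([[2., 2.]])
    sld = torch.tensor([[4., 2.07]])          # one layer + substrate, ambient 0 assumed
    r_dyn = abeles(q, thickness, roughness, sld)
    r_kin = kinematical_approximation(q, thickness, roughness, sld)
    # the two curves converge as q increases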
reflectorch/data_generation/reflectivity/memory_eff.py
@@ -0,0 +1,105 @@
+ # -*- coding: utf-8 -*-
+
+ from math import pi
+
+ import torch
+ from torch import Tensor
+
+
+ def abeles_memory_eff(
+     q: Tensor,
+     thickness: Tensor,
+     roughness: Tensor,
+     sld: Tensor,
+ ):
+     """Simulates reflectivity curves for SLD profiles with box model parameterization using a memory-efficient implementation of the Abeles matrix method.
+     It is computationally slower than the implementation in the 'abeles' function.
+
+     Args:
+         q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
+         thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
+         roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
+         sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
+             It includes the substrate but excludes the ambient medium, which is assumed to have an SLD of 0.
+
+     Returns:
+         Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
+     """
+     c_dtype = torch.complex128 if q.dtype is torch.float64 else torch.complex64
+
+     batch_size, num_layers = thickness.shape
+
+     sld = sld * 1e-6 + 1e-30j
+
+     num_interfaces = num_layers + 1
+
+     k_z0 = (q / 2).to(c_dtype)
+
+     if len(k_z0.shape) == 1:
+         k_z0.unsqueeze_(0)
+
+     thickness_prev_layer = 1.  # ambient
+
+     # accumulate the transfer matrix interface by interface instead of
+     # materializing all per-interface matrices at once
+     for interface_num in range(num_interfaces):
+
+         prev_layer_idx = interface_num - 1
+         next_layer_idx = interface_num
+
+         if interface_num == 0:
+             k_z_previous_layer = _get_relative_k_z(k_z0, torch.zeros(batch_size, 1).to(sld))
+         else:
+             thickness_prev_layer = thickness[:, prev_layer_idx].unsqueeze(1)
+             k_z_previous_layer = _get_relative_k_z(k_z0, sld[:, prev_layer_idx].unsqueeze(1))
+
+         k_z_next_layer = _get_relative_k_z(k_z0, sld[:, next_layer_idx].unsqueeze(1))  # (batch_num, q_num)
+
+         reflection_matrix = _make_reflection_matrix(
+             k_z_previous_layer, k_z_next_layer, roughness[:, interface_num].unsqueeze(1)
+         )
+
+         if interface_num == 0:
+             total_reflectivity_matrix = reflection_matrix
+         else:
+             translation_matrix = _make_translation_matrix(k_z_previous_layer, thickness_prev_layer)
+
+             total_reflectivity_matrix = torch.einsum(
+                 'bnmr, bmlr, bljr -> bnjr', total_reflectivity_matrix, translation_matrix, reflection_matrix
+             )
+
+     r = total_reflectivity_matrix[:, 0, 1] / total_reflectivity_matrix[:, 1, 1]
+
+     reflectivity = torch.clamp_max_(torch.abs(r) ** 2, 1.).flatten(1)
+
+     return reflectivity
+
+
+ def _get_relative_k_z(k_z0, scattering_length_density):
+     return torch.sqrt(k_z0 ** 2 - 4 * pi * scattering_length_density)
+
+
+ def _make_reflection_matrix(k_z_previous_layer, k_z_next_layer, interface_roughness):
+     p = _safe_div((k_z_previous_layer + k_z_next_layer), (2 * k_z_previous_layer)) * \
+         torch.exp(-(k_z_previous_layer - k_z_next_layer) ** 2 * 0.5 * interface_roughness ** 2)
+
+     m = _safe_div((k_z_previous_layer - k_z_next_layer), (2 * k_z_previous_layer)) * \
+         torch.exp(-(k_z_previous_layer + k_z_next_layer) ** 2 * 0.5 * interface_roughness ** 2)
+
+     return _stack_mtx(p, m, m, p)
+
+
+ def _stack_mtx(a11, a12, a21, a22):
+     return torch.stack([
+         torch.stack([a11, a12], dim=1),
+         torch.stack([a21, a22], dim=1),
+     ], dim=1)
+
+
+ def _make_translation_matrix(k_z, thickness):
+     return _stack_mtx(
+         torch.exp(-1j * k_z * thickness), torch.zeros_like(k_z),
+         torch.zeros_like(k_z), torch.exp(1j * k_z * thickness)
+     )
+
+
+ def _safe_div(numerator, denominator):
+     # where the denominator vanishes, fall back to the numerator instead of producing inf/nan
+     return torch.where(denominator == 0, numerator, torch.divide(numerator, denominator))
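Both functions implement the same recursion; at zero roughness they are exactly equivalent, while with roughness their damping conventions differ slightly (Nevot-Croce on the reflection coefficient in abeles versus per-matrix-element damping here), so a consistency check (editor's sketch) is easiest without roughness:

    import torch
    from reflectorch.data_generation.reflectivity.abeles import abeles
    from reflectorch.data_generation.reflectivity.memory_eff import abeles_memory_eff

    q = torch.linspace(0.01, 0.3, 128, dtype=torch.float64)
    thickness = torch.tensor([[80., 30.]], dtype=torch.float64)
    roughness = torch.zeros(1, 3, dtype=torch.float64)
    sld = torch.tensor([[3.2, 1.9, 2.07]], dtype=torch.float64)   # layers + substrate, ambient 0 assumed
    r_fast = abeles(q, thickness, roughness, sld)
    r_mem = abeles_memory_eff(q, thickness, roughness, sld)
    assert torch.allclose(r_fast, r_mem, rtol=1e-5, atol=1e-10)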