reflectorch 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of reflectorch might be problematic.

Files changed (41)
  1. reflectorch/data_generation/__init__.py +4 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +91 -16
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +97 -43
  8. reflectorch/data_generation/reflectivity/__init__.py +53 -11
  9. reflectorch/data_generation/reflectivity/kinematical.py +4 -5
  10. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  11. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  12. reflectorch/data_generation/smearing.py +42 -11
  13. reflectorch/data_generation/utils.py +93 -18
  14. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  15. reflectorch/inference/inference_model.py +795 -159
  16. reflectorch/inference/loading_data.py +37 -0
  17. reflectorch/inference/plotting.py +517 -0
  18. reflectorch/inference/preprocess_exp/interpolation.py +5 -2
  19. reflectorch/inference/scipy_fitter.py +98 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +131 -23
  26. reflectorch/models/__init__.py +2 -1
  27. reflectorch/models/encoders/__init__.py +2 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  31. reflectorch/models/networks/__init__.py +2 -0
  32. reflectorch/models/networks/mlp_networks.py +331 -153
  33. reflectorch/models/networks/residual_net.py +31 -5
  34. reflectorch/runs/train.py +0 -1
  35. reflectorch/runs/utils.py +48 -11
  36. reflectorch/utils.py +30 -0
  37. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/METADATA +20 -17
  38. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/RECORD +41 -36
  39. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
  40. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info/licenses}/LICENSE.txt +0 -0
  41. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
reflectorch/data_generation/priors/sampler_strategies.py:

@@ -73,11 +73,13 @@ class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
  roughness_mask: Tensor,
  logdist: bool = False,
  max_thickness_share: float = 0.5,
+ max_total_thickness: float = None,
  ):
  super().__init__(logdist=logdist)
  self.thickness_mask = thickness_mask
  self.roughness_mask = roughness_mask
  self.max_thickness_share = max_thickness_share
+ self.max_total_thickness = max_total_thickness
 
  def sample(self, batch_size: int,
  total_min_bounds: Tensor,
@@ -106,7 +108,8 @@ class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
  thickness_mask=self.thickness_mask.to(device),
  roughness_mask=self.roughness_mask.to(device),
  widths_sampler_func=self.widths_sampler_func,
- coef=self.max_thickness_share,
+ coef_roughness=self.max_thickness_share,
+ max_total_thickness=self.max_total_thickness,
  )
 
  class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
@@ -129,6 +132,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
  logdist: bool = False,
  max_thickness_share: float = 0.5,
  max_sld_share: float = 0.2,
+ max_total_thickness: float = None,
  ):
  super().__init__(logdist=logdist)
  self.thickness_mask = thickness_mask
@@ -137,6 +141,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
  self.isld_mask = isld_mask
  self.max_thickness_share = max_thickness_share
  self.max_sld_share = max_sld_share
+ self.max_total_thickness = max_total_thickness
 
  def sample(self, batch_size: int,
  total_min_bounds: Tensor,
@@ -169,6 +174,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
  widths_sampler_func=self.widths_sampler_func,
  coef_roughness=self.max_thickness_share,
  coef_isld=self.max_sld_share,
+ max_total_thickness=self.max_total_thickness,
  )
 
  def basic_sampler(
@@ -214,15 +220,44 @@ def constrained_roughness_sampler(
  thickness_mask: Tensor,
  roughness_mask: Tensor,
  widths_sampler_func,
- coef: float = 0.5,
+ coef_roughness: float = 0.5,
+ max_total_thickness: float = None,
  ):
  params, min_bounds, max_bounds = basic_sampler(
  batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
  widths_sampler_func=widths_sampler_func,
  )
 
+ if max_total_thickness is not None:
+ total_thickness = max_bounds[:, thickness_mask].sum(-1)
+ indices = total_thickness > max_total_thickness
+
+ if indices.any():
+ eps = 0.01
+ rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
+ scale_coef = max_total_thickness / total_thickness * rand_scale
+ scale_coef[~indices] = 1.0
+ min_bounds[:, thickness_mask] *= scale_coef[:, None]
+ max_bounds[:, thickness_mask] *= scale_coef[:, None]
+ params[:, thickness_mask] *= scale_coef[:, None]
+
+ min_bounds[:, thickness_mask] = torch.clamp_min(
+ min_bounds[:, thickness_mask],
+ total_min_bounds[:, thickness_mask],
+ )
+
+ max_bounds[:, thickness_mask] = torch.clamp_min(
+ max_bounds[:, thickness_mask],
+ total_min_bounds[:, thickness_mask],
+ )
+
+ params[:, thickness_mask] = torch.clamp_min(
+ params[:, thickness_mask],
+ total_min_bounds[:, thickness_mask],
+ )
+
  max_roughness = torch.minimum(
- get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef),
+ get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
  total_max_bounds[..., roughness_mask]
  )
  min_roughness = total_min_bounds[..., roughness_mask]
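The new max_total_thickness logic works as follows: after the basic bounds are drawn, any sample whose summed thickness bounds exceed the budget is scaled down (with a small random factor so samples do not all land exactly on the cap) and then clamped back up to the global minimum bounds. The same block is repeated verbatim in constrained_roughness_and_isld_sampler below. A minimal standalone sketch of the rescaling step (rescale_total_thickness is an illustrative name, not part of the package API):

    import torch

    def rescale_total_thickness(thicknesses, max_total_thickness, eps=0.01):
        # thicknesses: (batch_size, n_layers)
        total = thicknesses.sum(-1)                         # summed thickness per sample
        over = total > max_total_thickness                  # samples exceeding the budget
        # random factor in [1 - eps, 1) spreads samples just below the cap
        rand_scale = torch.rand_like(total) * eps + 1 - eps
        scale = max_total_thickness / total * rand_scale
        scale[~over] = 1.0                                  # compliant samples stay untouched
        return thicknesses * scale[:, None]

    t = torch.tensor([[100., 200., 400.],                   # total 700 > cap
                      [50., 60., 70.]])                     # total 180 <= cap
    print(rescale_total_thickness(t, max_total_thickness=500.).sum(-1))
    # first row shrinks to just under 500, second row is unchanged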
@@ -256,12 +291,41 @@ def constrained_roughness_and_isld_sampler(
  widths_sampler_func,
  coef_roughness: float = 0.5,
  coef_isld: float = 0.2,
+ max_total_thickness: float = None,
  ):
  params, min_bounds, max_bounds = basic_sampler(
  batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
  widths_sampler_func=widths_sampler_func,
  )
 
+ if max_total_thickness is not None:
+ total_thickness = max_bounds[:, thickness_mask].sum(-1)
+ indices = total_thickness > max_total_thickness
+
+ if indices.any():
+ eps = 0.01
+ rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
+ scale_coef = max_total_thickness / total_thickness * rand_scale
+ scale_coef[~indices] = 1.0
+ min_bounds[:, thickness_mask] *= scale_coef[:, None]
+ max_bounds[:, thickness_mask] *= scale_coef[:, None]
+ params[:, thickness_mask] *= scale_coef[:, None]
+
+ min_bounds[:, thickness_mask] = torch.clamp_min(
+ min_bounds[:, thickness_mask],
+ total_min_bounds[:, thickness_mask],
+ )
+
+ max_bounds[:, thickness_mask] = torch.clamp_min(
+ max_bounds[:, thickness_mask],
+ total_min_bounds[:, thickness_mask],
+ )
+
+ params[:, thickness_mask] = torch.clamp_min(
+ params[:, thickness_mask],
+ total_min_bounds[:, thickness_mask],
+ )
+
  max_roughness = torch.minimum(
  get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
  total_max_bounds[..., roughness_mask]
reflectorch/data_generation/q_generator.py:

@@ -16,6 +16,7 @@ __all__ = [
  "VariableQ",
  "EquidistantQ",
  "ConstantAngle",
+ "MaskedVariableQ",
  ]
 
 
@@ -50,6 +51,8 @@ class ConstantQ(QGenerator):
  q = q[1:]
  else:
  q = q[:-1]
+ self.q_min = q.min().item()
+ self.q_max = q.max().item()
  self.q = q
 
  def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
@@ -63,14 +66,26 @@ class ConstantQ(QGenerator):
  """
  return self.q.clone()[None].expand(batch_size, self.q.shape[0])
 
+ def scale_q(self, q):
+ """Scales the q values to the range [-1, 1].
+
+ Args:
+ q (Tensor): unscaled q values
+
+ Returns:
+ Tensor: scaled q values
+ """
+ scaled_q_01 = (q - self.q_min) / (self.q_max - self.q_min)
+ return 2.0 * (scaled_q_01 - 0.5)
+
 
  class VariableQ(QGenerator):
  """Q generator for reflectivity curves with variable discretization
 
  Args:
- q_min_range (list, optional): the range for sampling the minimum q value of the curves, *q_min*. Defaults to [0.01, 0.03].
- q_max_range (list, optional): the range for sampling the maximum q value of the curves, *q_max*. Defaults to [0.1, 0.5].
- n_q_range (list, optional): the range for the number of points in the curves (equidistantly sampled between *q_min* and *q_max*,
+ q_min_range (list, optional): the range for sampling the minimum q value of the curves, q_min. Defaults to [0.01, 0.03].
+ q_max_range (list, optional): the range for sampling the maximum q value of the curves, q_max. Defaults to [0.1, 0.5].
+ n_q_range (list, optional): the range for the number of points in the curves (equidistantly sampled between q_min and q_max,
  the number of points varies between batches but is constant within a batch). Defaults to [64, 256].
  device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
  dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
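The new scale_q helper is an affine map from the generator's [q_min, q_max] onto [-1, 1], e.g. for feeding q coordinates to a network. What it computes, in isolation:

    import torch

    q_min, q_max = 0.02, 0.3
    q = torch.linspace(q_min, q_max, 5)
    scaled = 2.0 * ((q - q_min) / (q_max - q_min) - 0.5)
    print(scaled)  # tensor([-1.0000, -0.5000,  0.0000,  0.5000,  1.0000])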
@@ -80,12 +95,14 @@ class VariableQ(QGenerator):
  q_min_range: Tuple[float, float] = (0.01, 0.03),
  q_max_range: Tuple[float, float] = (0.1, 0.5),
  n_q_range: Tuple[int, int] = (64, 256),
+ mode: str = 'equidistant',
  device=DEFAULT_DEVICE,
  dtype=DEFAULT_DTYPE,
  ):
  self.q_min_range = q_min_range
  self.q_max_range = q_max_range
  self.n_q_range = n_q_range
+ self.mode = mode
  self.device = device
  self.dtype = dtype
 
@@ -98,14 +115,23 @@
  Returns:
  Tensor: generated batch of q values
  """
- q_min = np.random.uniform(*self.q_min_range, batch_size)
- q_max = np.random.uniform(*self.q_max_range, batch_size)
- if self.n_q_range[0] == self.n_q_range[1]:
- n_q = self.n_q_range[0]
- else:
- n_q = np.random.randint(self.n_q_range[0], self.n_q_range[1] + 1)
-
- q = torch.from_numpy(np.linspace(q_min, q_max, n_q).T).to(self.device).to(self.dtype)
+
+ q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
+ q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
+
+ n_q = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (1,), device=self.device).item()
+
+ if self.mode == 'equidistant':
+ q = torch.linspace(0, 1, n_q, device=self.device, dtype=self.dtype)
+ elif self.mode == 'random':
+ q = torch.rand(n_q, device=self.device, dtype=self.dtype).sort().values
+ elif self.mode == 'logspace':
+ q = torch.logspace(
+ start=torch.log10(torch.tensor(1e-4, dtype=self.dtype, device=self.device)),
+ end=torch.log10(torch.tensor(1.0, dtype=self.dtype, device=self.device)),
+ steps=n_q, dtype=self.dtype, device=self.device)
+
+ q = q_min[:, None] + q * (q_max - q_min)[:, None]
 
  return q
 
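The rewritten get_batch stays on the torch device end to end: it draws per-curve [q_min, q_max] endpoints, builds positions on [0, 1] according to the chosen mode, and maps them affinely onto each curve's range. A compact illustration of the 'random' mode, with illustrative range values:

    import torch

    batch_size, n_q = 4, 8
    q_min = torch.rand(batch_size) * (0.03 - 0.01) + 0.01   # per-curve minimum q
    q_max = torch.rand(batch_size) * (0.5 - 0.1) + 0.1      # per-curve maximum q
    positions = torch.rand(n_q).sort().values               # mode='random': sorted unit-interval points
    q = q_min[:, None] + positions * (q_max - q_min)[:, None]
    print(q.shape)  # torch.Size([4, 8]); each row spans its own [q_min, q_max]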
@@ -178,49 +204,77 @@ class EquidistantQ(QGenerator):
  return qs
 
 
- class TransformerQ(QGenerator):
+ class MaskedVariableQ:
  def __init__(self,
- q_max: float = 0.2,
- num_values: Union[int, Tuple[int, int]] = (30, 512),
- min_dq_ratio: float = 5.,
- device=None,
- dtype=torch.float64,
- ):
- self.min_dq_ratio = min_dq_ratio
- self.q_max = q_max
- self._dq_range = q_max / num_values[1], q_max / num_values[0]
- self._num_values = num_values
+ q_min_range=(0.01, 0.03),
+ q_max_range=(0.1, 0.5),
+ n_q_range=(64, 256),
+ mode='equidistant',
+ shuffle_mask=False,
+ total_thickness_constraint=True,
+ min_points_per_fringe=4,
+ device=DEFAULT_DEVICE,
+ dtype=DEFAULT_DTYPE):
+ self.q_min_range = q_min_range
+ self.q_max_range = q_max_range
+ self.n_q_range = n_q_range
  self.device = device
  self.dtype = dtype
-
- def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
+ self.mode = mode
+ self.shuffle_mask = shuffle_mask
+ self.total_thickness_constraint = total_thickness_constraint
+ self.min_points_per_fringe = min_points_per_fringe
+
+ def get_batch(self, batch_size, context):
  assert context is not None
 
- params: BasicParams = context['params']
- total_thickness = params.thicknesses.sum(-1)
+ q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
+ q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
 
- assert total_thickness.shape[0] == batch_size
+ max_n_q = self.n_q_range[1]
 
- min_dqs = torch.clamp(
- 2 * np.pi / total_thickness / self.min_dq_ratio, self._dq_range[0], self._dq_range[1] * 0.9
- )
+ if self.mode == 'equidistant':
+ positions = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(batch_size, max_n_q)
+ elif self.mode == 'random':
+ positions = torch.rand(batch_size, max_n_q, device=self.device, dtype=self.dtype)
+ positions, _ = positions.sort(dim=-1)
+ elif self.mode == 'mixed':
+ positions = torch.empty(batch_size, max_n_q, device=self.device, dtype=self.dtype)
 
- dqs = torch.rand_like(min_dqs) * (self._dq_range[1] - min_dqs) + min_dqs
+ half = batch_size // 2 # half batch gets equidistant
+ eq_pos = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(half, max_n_q)
+ positions[:half] = eq_pos
 
- num_q_values = torch.clamp(self.q_max // dqs, *self._num_values).to(torch.int)
+ rand_pos = torch.rand(batch_size - half, max_n_q, device=self.device, dtype=self.dtype) # other half gets sorted random
+ rand_pos, _ = rand_pos.sort(dim=-1)
+ positions[half:] = rand_pos
+ else:
+ raise ValueError(f"Unknown spacing mode: {self.mode}")
 
- q_values, mask = generate_q_padding_mask(num_q_values, self.q_max)
+ q = q_min[:, None] + positions * (q_max - q_min)[:, None]
 
- context['tgt_key_padding_mask'] = mask
- context['num_q_values'] = num_q_values
+ n_qs = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (batch_size,), device=self.device)
 
- return q_values
+ if 'params' in context and self.total_thickness_constraint: ### N_points > 1 + (Q_spread * total_thickness * min_np_per_kiessing_fringe) / (2*pi)
+ d_total = context['params'].thicknesses.sum(-1)
+ limit = 1 + ((q_max - q_min) * d_total * self.min_points_per_fringe) / (2*np.pi)
+ limit = limit.ceil().int()
+ n_qs = torch.maximum(n_qs, limit)
+ n_qs = torch.clamp(n_qs, max=self.n_q_range[1])
 
+ indices = torch.arange(max_n_q, device=self.device).expand(batch_size, max_n_q)
+ valid_mask = indices < n_qs[:, None] # right side padding
+
+ if self.shuffle_mask: # shuffle valid positions (inter-spread padding)
+ perm = torch.argsort(torch.rand(batch_size, max_n_q, device=self.device), dim=-1)
+ valid_mask = torch.gather(valid_mask, dim=1, index=perm)
+
+ context['key_padding_mask'] = valid_mask
+ context['n_points'] = valid_mask.sum(dim=-1)
+
+ return q
+
+ def scale_q(self, q):
+ scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])
 
- def generate_q_padding_mask(num_q_values: Tensor, q_max: float):
- batch_size = num_q_values.shape[0]
- dqs = (q_max / num_q_values)[:, None]
- q_values = torch.arange(1, num_q_values.max().item() + 1)[None].repeat(batch_size, 1) * dqs
- mask = (q_values > q_max + dqs / 2)
- q_values[mask] = 0.
- return q_values, mask
+ return 2.0 * (scaled_q_01 - 0.5)
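The total_thickness_constraint encodes a Nyquist-like bound, spelled out in the inline comment: a film of total thickness d_total produces Kiessig fringes of period 2π/d_total in q, so the span q_max − q_min contains (q_max − q_min)·d_total/(2π) fringes, and requiring min_points_per_fringe points per fringe yields the lower limit on n_qs used above. Worked numerically with example values:

    import numpy as np

    d_total = 1200.0                    # total film thickness in angstroms
    q_min, q_max = 0.02, 0.3            # q range of the curve
    min_points_per_fringe = 4

    n_fringes = (q_max - q_min) * d_total / (2 * np.pi)      # ~53.5 fringes in range
    n_min = int(np.ceil(1 + n_fringes * min_points_per_fringe))
    print(n_min)  # 215 -> n_qs is raised to at least this (then clamped to n_q_range[1])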
reflectorch/data_generation/reflectivity/__init__.py:

@@ -10,6 +10,7 @@ from reflectorch.data_generation.reflectivity.numpy_implementations import (
  abeles_np,
  )
  from reflectorch.data_generation.reflectivity.smearing import abeles_constant_smearing
+ from reflectorch.data_generation.reflectivity.smearing_pointwise import abeles_pointwise_smearing
  from reflectorch.data_generation.reflectivity.kinematical import kinematical_approximation
 
 
@@ -20,9 +21,15 @@ def reflectivity(
  sld: Tensor,
  dq: Tensor = None,
  gauss_num: int = 51,
- constant_dq: bool = True,
+ constant_dq: bool = False,
  log: bool = False,
- abeles_func=None,
+ q_shift: Tensor = 0.0,
+ r_scale: Tensor = 1.0,
+ background: Tensor = 0.0,
+ solvent_vf = None,
+ solvent_mode = 'fronting',
+ abeles_func = None,
+ **abeles_kwargs
  ):
  """Function which computes the reflectivity curves from thin film parameters.
  By default it uses the fast implementation of the Abeles matrix formalism.
@@ -37,24 +44,59 @@
  Either dq if ``constant_dq`` is ``True`` or dq/q if ``constant_dq`` is ``False``. Defaults to None.
  gauss_num (int, optional): the number of gaussians for curve smearing. Defaults to 51.
  constant_dq (bool, optional): if ``True`` the smearing is constant (constant dq at each point in the curve)
- otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to True.
+ otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to False.
  log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to False.
+ q_shift (float or Tensor, optional): misalignment in q.
+ r_scale (float or Tensor, optional): normalization factor (scales reflectivity).
+ background (float or Tensor, optional): background intensity.
  abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default Abeles matrix implementation ('abeles'). Defaults to None.
-
+ abeles_kwargs: Additional arguments specific to the chosen `abeles_func`.
  Returns:
- Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
+ Tensor: the computed reflectivity curves
  """
  abeles_func = abeles_func or abeles
- q = torch.atleast_2d(q)
+ q = torch.atleast_2d(q) + q_shift
+ q = torch.clamp(q, min=0.0)
+
+ if solvent_vf is not None:
+ num_layers = thickness.shape[-1]
+ if solvent_mode == 'fronting':
+ assert sld.shape[-1] == num_layers + 2
+ assert solvent_vf.shape[-1] == num_layers
+ solvent_sld = sld[..., [0]]
+ idx = slice(1, num_layers)
+ sld[..., idx] = solvent_vf * solvent_sld + (1.0 - solvent_vf) * sld[..., idx]
+ elif solvent_mode == 'backing':
+ solvent_sld = sld[..., [-1]]
+ idx = slice(1, num_layers) if sld.shape[-1] == num_layers + 2 else slice(0, num_layers)
+ sld[..., idx] = solvent_vf * solvent_sld + (1.0 - solvent_vf) * sld[..., idx]
+ else:
+ raise NotImplementedError
 
  if dq is None:
- reflectivity_curves = abeles_func(q, thickness, roughness, sld)
+ reflectivity_curves = abeles_func(q, thickness, roughness, sld, **abeles_kwargs)
  else:
- reflectivity_curves = abeles_constant_smearing(
- q, thickness, roughness, sld,
- dq=dq, gauss_num=gauss_num, constant_dq=constant_dq, abeles_func=abeles_func
- )
+ if dq.shape[-1] > 1:
+ reflectivity_curves = abeles_pointwise_smearing(
+ q=q, dq=dq, thickness=thickness, roughness=roughness, sld=sld,
+ abeles_func=abeles_func, gauss_num=gauss_num,
+ **abeles_kwargs,
+ )
+ else:
+ reflectivity_curves = abeles_constant_smearing(
+ q, thickness, roughness, sld,
+ dq=dq, gauss_num=gauss_num, constant_dq=constant_dq, abeles_func=abeles_func,
+ **abeles_kwargs,
+ )
+
+ if isinstance(r_scale, Tensor):
+ r_scale = r_scale.view(-1, *[1] * (reflectivity_curves.dim() - 1))
+ if isinstance(background, Tensor):
+ background = background.view(-1, *[1] * (reflectivity_curves.dim() - 1))
+
+ reflectivity_curves = reflectivity_curves * r_scale + background
 
  if log:
  reflectivity_curves = torch.log10(reflectivity_curves)
+
  return reflectivity_curves
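The r_scale and background tensors are applied after simulation; the view(-1, 1, ..., 1) reshape lets per-sample scalars broadcast over the curve (and channel) axes. A self-contained sketch of just that broadcasting step, with stand-in values:

    import torch

    curves = torch.rand(4, 128)                       # stand-in for simulated curves
    r_scale = torch.tensor([0.9, 1.0, 1.1, 0.95])     # per-sample normalization
    background = torch.tensor([1e-7, 1e-6, 1e-7, 1e-6])

    r_scale = r_scale.view(-1, *[1] * (curves.dim() - 1))        # shape (4, 1)
    background = background.view(-1, *[1] * (curves.dim() - 1))  # shape (4, 1)
    distorted = curves * r_scale + background
    print(distorted.shape)  # torch.Size([4, 128])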
reflectorch/data_generation/reflectivity/kinematical.py:

@@ -51,7 +51,7 @@
 
  substrate_sld = sld[:, -1:]
 
- rf = _get_resnel_reflectivity(q, substrate_sld[:, None])
+ rf = _get_fresnel_reflectivity(q, substrate_sld[:, None])
 
  r = torch.clamp_max_(r * rf / substrate_sld.real ** 2, 1.)
 
@@ -61,12 +61,11 @@
  return r
 
 
- def _get_resnel_reflectivity(q, substrate_slds):
- _RE_CONST = 0.28174103675406496
+ def _get_fresnel_reflectivity(q, substrate_slds):
+ _RE_CONST = 0.28174103675406496 # 2/sqrt(16*pi)
 
  q_c = torch.sqrt(substrate_slds + 0j) / _RE_CONST * 2
  q_prime = torch.sqrt(q ** 2 - q_c ** 2 + 0j)
  r_f = ((q - q_prime) / (q + q_prime)).abs().float() ** 2
 
- return r_f.squeeze(-1)
-
+ return r_f.squeeze(-1)
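Beyond the name fix (_get_resnel_reflectivity → _get_fresnel_reflectivity), the helper is unchanged: it evaluates the Fresnel reflectivity of the bare substrate, which equals 1 below the critical edge q_c and falls off roughly as q^-4 far above it. A standalone sketch, assuming SLD is given in Å^-2 at this point in the code (e.g. ~2.07e-6 for Si with neutrons, giving the known q_c ≈ 0.0102 Å^-1):

    import torch

    _RE_CONST = 0.28174103675406496
    substrate_sld = torch.tensor(2.07e-6 + 0j)
    q = torch.tensor([0.005, 0.0102, 0.05, 0.2])

    q_c = torch.sqrt(substrate_sld) / _RE_CONST * 2      # critical edge, ~0.0102 here
    q_prime = torch.sqrt(q ** 2 - q_c ** 2 + 0j)
    r_f = ((q - q_prime) / (q + q_prime)).abs() ** 2
    print(r_f)  # 1.0 below q_c, then a rapid ~q^-4 decay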
reflectorch/data_generation/reflectivity/smearing.py:

@@ -14,26 +14,41 @@ def abeles_constant_smearing(
  roughness: Tensor,
  sld: Tensor,
  dq: Tensor = None,
- gauss_num: int = 51,
- constant_dq: bool = True,
+ gauss_num: int = 31,
+ constant_dq: bool = False,
  abeles_func=None,
+ **abeles_kwargs
  ):
  abeles_func = abeles_func or abeles
+
+ if dq.dtype != thickness.dtype:
+ q = q.to(thickness)
+
+ if dq.dtype != thickness.dtype:
+ dq = dq.to(thickness)
+
+ if q.shape[0] == 1:
+ q = q.repeat(thickness.shape[0], 1)
+
  q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
  kernels = _get_t_gauss_kernels(dq, gauss_num)
-
- curves = abeles_func(q_lin, thickness, roughness, sld)
+
+ curves = abeles_func(q_lin, thickness, roughness, sld, **abeles_kwargs)
 
  padding = (kernels.shape[-1] - 1) // 2
+ padded_curves = pad(curves, (padding, padding), 'reflect')
+
  smeared_curves = conv1d(
- pad(curves[None], (padding, padding), 'reflect'), kernels[:, None], groups=kernels.shape[0],
- )[0]
+ padded_curves, kernels[:, None], groups=kernels.shape[0],
+ )
 
  if q.shape[0] != smeared_curves.shape[0]:
- q = q.expand(smeared_curves.shape[0], *q.shape[1:])
-
+ repeat_factor = smeared_curves.shape[0] // q.shape[0]
+ q = q.repeat(repeat_factor, 1)
+ q_lin = q_lin.repeat(repeat_factor, 1)
+
  smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)
-
+
  return smeared_curves
 
 
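The restructured smearing body makes the convolution step explicit: every curve in the batch is convolved with its own Gaussian kernel in a single grouped conv1d call (groups = batch size), and, like the new code above, it relies on conv1d accepting unbatched (C, L) input in recent PyTorch. The core pattern in isolation, with random stand-ins for the curves and kernels:

    import torch
    from torch.nn.functional import conv1d, pad

    batch, n_points, kernel_size = 3, 64, 9
    curves = torch.rand(batch, n_points)
    kernels = torch.softmax(torch.randn(batch, kernel_size), dim=-1)  # normalized stand-in kernels

    padding = (kernel_size - 1) // 2
    padded = pad(curves, (padding, padding), 'reflect')   # reflect-pad so length is preserved
    smeared = conv1d(padded, kernels[:, None], groups=batch)
    print(smeared.shape)  # torch.Size([3, 64])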
@@ -55,7 +70,7 @@ def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
  return gauss_y
 
 
- def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = True):
+ def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = False):
  if constant_dq:
  return _get_q_axes_for_constant_dq(q, resolutions, gaussnum)
  else:
reflectorch/data_generation/reflectivity/smearing_pointwise.py (new file):

@@ -0,0 +1,110 @@
+ import torch
+ import scipy
+ import numpy as np
+ from functools import lru_cache
+ from typing import Tuple
+
+ from reflectorch.data_generation.reflectivity.abeles import abeles
+
+ #Pytorch version based on the JAX implementation of pointwise smearing in the refnx package.
+
+ @lru_cache(maxsize=128)
+ def gauss_legendre(n: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Calculate Gaussian quadrature abscissae and weights.
+
+ Args:
+ n (int): Gaussian quadrature order.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: The abscissae and weights for Gauss-Legendre integration.
+ """
+ return scipy.special.p_roots(n)
+
+ def gauss(x: torch.Tensor) -> torch.Tensor:
+ """
+ Calculate the Gaussian function.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output tensor after applying the Gaussian function.
+ """
+ return torch.exp(-0.5 * x * x)
+
+ def abeles_pointwise_smearing(
+ q: torch.Tensor,
+ dq: torch.Tensor,
+ thickness: torch.Tensor,
+ roughness: torch.Tensor,
+ sld: torch.Tensor,
+ gauss_num: int = 17,
+ abeles_func=None,
+ **abeles_kwargs,
+ ) -> torch.Tensor:
+ """
+ Compute reflectivity with variable smearing using Gaussian quadrature.
+
+ Args:
+ q (torch.Tensor): The momentum transfer (q) values.
+ dq (torch.Tensor): The resolution for curve smearing.
+ thickness (torch.Tensor): The layer thicknesses.
+ roughness (torch.Tensor): The interlayer roughnesses.
+ sld (torch.Tensor): The SLDs of the layers.
+ sld_magnetic (torch.Tensor, optional): The magnetic SLDs of the layers.
+ magnetization_angle (torch.Tensor, optional): The magnetization angles.
+ polarizer_eff (torch.Tensor, optional): The polarizer efficiency.
+ analyzer_eff (torch.Tensor, optional): The analyzer efficiency.
+ abeles_func (Callable, optional): A function implementing the simulation of the reflectivity curves.
+ gauss_num (int, optional): Gaussian quadrature order. Defaults to 17.
+
+ Returns:
+ torch.Tensor: The computed reflectivity curves.
+ """
+ abeles_func = abeles_func or abeles
+
+ if q.shape[0] == 1:
+ q = q.repeat(thickness.shape[0], 1)
+
+ _FWHM = 2 * np.sqrt(2 * np.log(2.0))
+ _INTLIMIT = 3.5
+
+ bs = q.shape[0]
+ nq = q.shape[-1]
+ device = q.device
+
+ quad_order = gauss_num
+ abscissa, weights = gauss_legendre(quad_order)
+ abscissa = torch.tensor(abscissa)[None, :, None].to(device)
+ weights = torch.tensor(weights)[None, :, None].to(device)
+ prefactor = 1.0 / np.sqrt(2 * np.pi)
+
+ gaussvals = prefactor * gauss(abscissa * _INTLIMIT)
+
+ va = q[:, None, :] - _INTLIMIT * dq[:, None, :] / _FWHM
+ vb = q[:, None, :] + _INTLIMIT * dq[:, None, :] / _FWHM
+
+ qvals_for_res_0 = (abscissa * (vb - va) + vb + va) / 2
+ qvals_for_res = qvals_for_res_0.reshape(bs, -1)
+
+ refl_curves = abeles_func(
+ q=qvals_for_res,
+ thickness=thickness,
+ roughness=roughness,
+ sld=sld,
+ **abeles_kwargs
+ )
+
+ # Handle multiple channels
+ if refl_curves.dim() == 3:
+ n_channels = refl_curves.shape[1]
+ refl_curves = refl_curves.reshape(bs, n_channels, quad_order, nq)
+ refl_curves = refl_curves * gaussvals.unsqueeze(1) * weights.unsqueeze(1)
+ refl_curves = torch.sum(refl_curves, dim=2) * _INTLIMIT
+ else:
+ refl_curves = refl_curves.reshape(bs, quad_order, nq)
+ refl_curves = refl_curves * gaussvals * weights
+ refl_curves = torch.sum(refl_curves, dim=1) * _INTLIMIT
+
+ return refl_curves
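A quick numerical check of the quadrature building blocks in the new file: the cached Gauss-Legendre nodes and weights, mapped onto [−_INTLIMIT, _INTLIMIT], integrate the unit Gaussian to ≈0.9995, which is the normalization implied by the sum-times-_INTLIMIT step above:

    import numpy as np
    import scipy.special
    import torch

    abscissa, weights = scipy.special.p_roots(17)         # what gauss_legendre(17) caches
    x = torch.tensor(abscissa) * 3.5                      # map nodes to [-_INTLIMIT, _INTLIMIT]
    vals = torch.exp(-0.5 * x * x) / np.sqrt(2 * np.pi)   # prefactor * gauss(x)
    integral = (torch.tensor(weights) * vals).sum() * 3.5  # scale by interval half-width
    print(integral.item())  # ~0.9995: nearly all the mass of a unit Gaussian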