reflectorch 1.2.0__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of reflectorch might be problematic. (The original page linked to further details here; that link was not preserved in this copy.)

Files changed (39):
  1. reflectorch/data_generation/__init__.py +2 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +90 -15
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +31 -11
  8. reflectorch/data_generation/reflectivity/__init__.py +56 -14
  9. reflectorch/data_generation/reflectivity/abeles.py +31 -16
  10. reflectorch/data_generation/reflectivity/kinematical.py +5 -6
  11. reflectorch/data_generation/reflectivity/memory_eff.py +1 -1
  12. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  13. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  14. reflectorch/data_generation/smearing.py +42 -11
  15. reflectorch/data_generation/utils.py +92 -18
  16. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  17. reflectorch/inference/inference_model.py +220 -105
  18. reflectorch/inference/plotting.py +98 -0
  19. reflectorch/inference/scipy_fitter.py +84 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +122 -23
  26. reflectorch/models/__init__.py +1 -1
  27. reflectorch/models/encoders/__init__.py +0 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/networks/__init__.py +2 -0
  31. reflectorch/models/networks/mlp_networks.py +324 -152
  32. reflectorch/models/networks/residual_net.py +31 -5
  33. reflectorch/runs/train.py +0 -1
  34. reflectorch/runs/utils.py +43 -9
  35. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/METADATA +19 -17
  36. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/RECORD +39 -36
  37. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/WHEEL +1 -1
  38. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info/licenses}/LICENSE.txt +0 -0
  39. {reflectorch-1.2.0.dist-info → reflectorch-1.3.0.dist-info}/top_level.txt +0 -0
@@ -50,6 +50,8 @@ class ConstantQ(QGenerator):
50
50
  q = q[1:]
51
51
  else:
52
52
  q = q[:-1]
53
+ self.q_min = q.min().item()
54
+ self.q_max = q.max().item()
53
55
  self.q = q
54
56
 
55
57
  def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
@@ -63,14 +65,26 @@ class ConstantQ(QGenerator):
63
65
  """
64
66
  return self.q.clone()[None].expand(batch_size, self.q.shape[0])
65
67
 
68
+ def scale_q(self, q):
69
+ """Scales the q values to the range [-1, 1].
70
+
71
+ Args:
72
+ q (Tensor): unscaled q values
73
+
74
+ Returns:
75
+ Tensor: scaled q values
76
+ """
77
+ scaled_q_01 = (q - self.q_min) / (self.q_max - self.q_min)
78
+ return 2.0 * (scaled_q_01 - 0.5)
79
+
66
80
 
67
81
  class VariableQ(QGenerator):
68
82
  """Q generator for reflectivity curves with variable discretization
69
83
 
70
84
  Args:
71
- q_min_range (list, optional): the range for sampling the minimum q value of the curves, *q_min*. Defaults to [0.01, 0.03].
72
- q_max_range (list, optional): the range for sampling the maximum q value of the curves, *q_max*. Defaults to [0.1, 0.5].
73
- n_q_range (list, optional): the range for the number of points in the curves (equidistantly sampled between *q_min* and *q_max*,
85
+ q_min_range (list, optional): the range for sampling the minimum q value of the curves, q_min. Defaults to [0.01, 0.03].
86
+ q_max_range (list, optional): the range for sampling the maximum q value of the curves, q_max. Defaults to [0.1, 0.5].
87
+ n_q_range (list, optional): the range for the number of points in the curves (equidistantly sampled between q_min and q_max,
74
88
  the number of points varies between batches but is constant within a batch). Defaults to [64, 256].
75
89
  device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
76
90
  dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
@@ -80,12 +94,14 @@ class VariableQ(QGenerator):
80
94
  q_min_range: Tuple[float, float] = (0.01, 0.03),
81
95
  q_max_range: Tuple[float, float] = (0.1, 0.5),
82
96
  n_q_range: Tuple[int, int] = (64, 256),
97
+ mode: str = 'equidistant',
83
98
  device=DEFAULT_DEVICE,
84
99
  dtype=DEFAULT_DTYPE,
85
100
  ):
86
101
  self.q_min_range = q_min_range
87
102
  self.q_max_range = q_max_range
88
103
  self.n_q_range = n_q_range
104
+ self.mode = mode
89
105
  self.device = device
90
106
  self.dtype = dtype
91
107
 
@@ -98,14 +114,18 @@ class VariableQ(QGenerator):
98
114
  Returns:
99
115
  Tensor: generated batch of q values
100
116
  """
101
- q_min = np.random.uniform(*self.q_min_range, batch_size)
102
- q_max = np.random.uniform(*self.q_max_range, batch_size)
103
- if self.n_q_range[0] == self.n_q_range[1]:
104
- n_q = self.n_q_range[0]
105
- else:
106
- n_q = np.random.randint(self.n_q_range[0], self.n_q_range[1] + 1)
107
-
108
- q = torch.from_numpy(np.linspace(q_min, q_max, n_q).T).to(self.device).to(self.dtype)
117
+
118
+ q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
119
+ q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
120
+
121
+ n_q = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (1,), device=self.device).item()
122
+
123
+ if self.mode == 'equidistant':
124
+ q = torch.linspace(0, 1, n_q, device=self.device, dtype=self.dtype)
125
+ elif self.mode == 'random':
126
+ q = torch.rand(n_q, device=self.device, dtype=self.dtype).sort().values
127
+
128
+ q = q_min[:, None] + q * (q_max - q_min)[:, None]
109
129
 
110
130
  return q
111
131
 
@@ -10,6 +10,7 @@ from reflectorch.data_generation.reflectivity.numpy_implementations import (
10
10
  abeles_np,
11
11
  )
12
12
  from reflectorch.data_generation.reflectivity.smearing import abeles_constant_smearing
13
+ from reflectorch.data_generation.reflectivity.smearing_pointwise import abeles_pointwise_smearing
13
14
  from reflectorch.data_generation.reflectivity.kinematical import kinematical_approximation
14
15
 
15
16
 
@@ -20,9 +21,15 @@ def reflectivity(
20
21
  sld: Tensor,
21
22
  dq: Tensor = None,
22
23
  gauss_num: int = 51,
23
- constant_dq: bool = True,
24
+ constant_dq: bool = False,
24
25
  log: bool = False,
25
- abeles_func=None,
26
+ q_shift: Tensor = 0.0,
27
+ r_scale: Tensor = 1.0,
28
+ background: Tensor = 0.0,
29
+ solvent_vf = None,
30
+ solvent_mode = 'fronting',
31
+ abeles_func = None,
32
+ **abeles_kwargs
26
33
  ):
27
34
  """Function which computes the reflectivity curves from thin film parameters.
28
35
  By default it uses the fast implementation of the Abeles matrix formalism.
@@ -31,30 +38,65 @@ def reflectivity(
31
38
  q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
32
39
  thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
33
40
  roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
34
- sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
35
- It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
41
+ sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom) with shape
42
+ [batch_size, n_layers + 1] (excluding ambient SLD which is assumed to be 0) or [batch_size, n_layers + 2] (including ambient SLD; only for the default ``abeles_func='abeles'``)
36
43
  dq (Tensor, optional): tensor of resolutions used for curve smearing with shape [batch_size, 1].
37
44
  Either dq if ``constant_dq`` is ``True`` or dq/q if ``constant_dq`` is ``False``. Defaults to None.
38
45
  gauss_num (int, optional): the number of gaussians for curve smearing. Defaults to 51.
39
46
  constant_dq (bool, optional): if ``True`` the smearing is constant (constant dq at each point in the curve)
40
- otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to True.
47
+ otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to False.
41
48
  log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to False.
42
- abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default Abeles matrix implementation. Defaults to None.
43
-
49
+ q_shift (float or Tensor, optional): misalignment in q.
50
+ r_scale (float or Tensor, optional): normalization factor (scales reflectivity).
51
+ background (float or Tensor, optional): background intensity.
52
+ abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default Abeles matrix implementation ('abeles'). Defaults to None.
53
+ abeles_kwargs: Additional arguments specific to the chosen `abeles_func`.
44
54
  Returns:
45
- Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
55
+ Tensor: the computed reflectivity curves
46
56
  """
47
57
  abeles_func = abeles_func or abeles
48
- q = torch.atleast_2d(q)
58
+ q = torch.atleast_2d(q) + q_shift
59
+ q = torch.clamp(q, min=0.0)
60
+
61
+ if solvent_vf is not None:
62
+ num_layers = thickness.shape[-1]
63
+ if solvent_mode == 'fronting':
64
+ assert sld.shape[-1] == num_layers + 2
65
+ assert solvent_vf.shape[-1] == num_layers
66
+ solvent_sld = sld[..., [0]]
67
+ idx = slice(1, num_layers)
68
+ sld[..., idx] = solvent_vf * solvent_sld + (1.0 - solvent_vf) * sld[..., idx]
69
+ elif solvent_mode == 'backing':
70
+ solvent_sld = sld[..., [-1]]
71
+ idx = slice(1, num_layers) if sld.shape[-1] == num_layers + 2 else slice(0, num_layers)
72
+ sld[..., idx] = solvent_vf * solvent_sld + (1.0 - solvent_vf) * sld[..., idx]
73
+ else:
74
+ raise NotImplementedError
49
75
 
50
76
  if dq is None:
51
- reflectivity_curves = abeles_func(q, thickness, roughness, sld)
77
+ reflectivity_curves = abeles_func(q, thickness, roughness, sld, **abeles_kwargs)
52
78
  else:
53
- reflectivity_curves = abeles_constant_smearing(
54
- q, thickness, roughness, sld,
55
- dq=dq, gauss_num=gauss_num, constant_dq=constant_dq, abeles_func=abeles_func
56
- )
79
+ if dq.shape[-1] > 1:
80
+ reflectivity_curves = abeles_pointwise_smearing(
81
+ q=q, dq=dq, thickness=thickness, roughness=roughness, sld=sld,
82
+ abeles_func=abeles_func, gauss_num=gauss_num,
83
+ **abeles_kwargs,
84
+ )
85
+ else:
86
+ reflectivity_curves = abeles_constant_smearing(
87
+ q, thickness, roughness, sld,
88
+ dq=dq, gauss_num=gauss_num, constant_dq=constant_dq, abeles_func=abeles_func,
89
+ **abeles_kwargs,
90
+ )
91
+
92
+ if isinstance(r_scale, Tensor):
93
+ r_scale = r_scale.view(-1, *[1] * (reflectivity_curves.dim() - 1))
94
+ if isinstance(background, Tensor):
95
+ background = background.view(-1, *[1] * (reflectivity_curves.dim() - 1))
96
+
97
+ reflectivity_curves = reflectivity_curves * r_scale + background
57
98
 
58
99
  if log:
59
100
  reflectivity_curves = torch.log10(reflectivity_curves)
101
+
60
102
  return reflectivity_curves
@@ -1,5 +1,6 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  import math
3
+ from functools import reduce
3
4
 
4
5
  import torch
5
6
  from torch import Tensor
@@ -17,8 +18,9 @@ def abeles(
17
18
  q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
18
19
  thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
19
20
  roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
20
- sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
21
- It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
21
+ sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom). The tensor shape should be one of the following:
22
+ - [batch_size, n_layers + 1]: in this case, the ambient SLD is not included but assumed to be 0
23
+ - [batch_size, n_layers + 2]: this shape includes the ambient SLD as the first element in the tensor
22
24
 
23
25
  Returns:
24
26
  Tensor: tensor containing the simulated reflectivity curves with shape [batch_size, n_points]
@@ -27,11 +29,24 @@ def abeles(
27
29
 
28
30
  batch_size, num_layers = thickness.shape
29
31
 
30
- sld = torch.cat([torch.zeros(batch_size, 1).to(sld), sld], -1)[:, None]
31
- thickness = torch.cat([torch.zeros(batch_size, 1).to(thickness), thickness], -1)[:, None]
32
+ if sld.shape[-1] == num_layers + 1:
33
+ # add zero ambient sld
34
+ sld = torch.cat([torch.zeros(batch_size, 1).to(sld), sld], -1)
35
+ if sld.shape[-1] != num_layers + 2:
36
+ raise ValueError(
37
+ "Number of SLD values does not equal to num_layers + 2 (substrate + ambient)."
38
+ )
39
+
40
+ sld = sld[:, None]
41
+
42
+ # add zero thickness for ambient layer:
43
+ thickness = torch.cat([torch.zeros(batch_size, 1).to(thickness), thickness], -1)[
44
+ :, None
45
+ ]
46
+
32
47
  roughness = roughness[:, None] ** 2
33
48
 
34
- sld = sld * 1e-6 + 1e-30j
49
+ sld = (sld - sld[..., :1]) * 1e-6 + 1e-36j
35
50
 
36
51
  k_z0 = (q / 2).to(c_dtype)
37
52
 
@@ -41,7 +56,7 @@ def abeles(
41
56
  if k_z0.dim() == 2:
42
57
  k_z0.unsqueeze_(-1)
43
58
 
44
- k_n = torch.sqrt(k_z0 ** 2 - 4 * math.pi * sld)
59
+ k_n = torch.sqrt(k_z0**2 - 4 * math.pi * sld)
45
60
 
46
61
  # k_n.shape - (batch, q, layers)
47
62
 
@@ -52,22 +67,22 @@ def abeles(
52
67
  exp_beta = torch.exp(beta)
53
68
  exp_m_beta = torch.exp(-beta)
54
69
 
55
- rn = (k_n - k_np1) / (k_n + k_np1) * torch.exp(- 2 * k_n * k_np1 * roughness)
70
+ rn = (k_n - k_np1) / (k_n + k_np1) * torch.exp(-2 * k_n * k_np1 * roughness)
56
71
 
57
- c_matrices = torch.stack([
58
- torch.stack([exp_beta, rn * exp_m_beta], -1),
59
- torch.stack([rn * exp_beta, exp_m_beta], -1),
60
- ], -1)
72
+ c_matrices = torch.stack(
73
+ [
74
+ torch.stack([exp_beta, rn * exp_m_beta], -1),
75
+ torch.stack([rn * exp_beta, exp_m_beta], -1),
76
+ ],
77
+ -1,
78
+ )
61
79
 
62
80
  c_matrices = [c.squeeze(-3) for c in c_matrices.split(1, -3)]
63
81
 
64
- m, c_matrices = c_matrices[0], c_matrices[1:]
65
-
66
- for c in c_matrices:
67
- m = m @ c
82
+ m = reduce(torch.matmul, c_matrices)
68
83
 
69
84
  r = (m[..., 1, 0] / m[..., 0, 0]).abs() ** 2
70
- r = torch.clamp_max_(r, 1.)
85
+ r = torch.clamp_max_(r, 1.0)
71
86
 
72
87
  return r
73
88
 
@@ -19,7 +19,7 @@ def kinematical_approximation(
19
19
  q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
20
20
  thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
21
21
  roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
22
- sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
22
+ sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
23
23
  It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
24
24
  apply_fresnel (bool, optional): whether to use the Fresnel coefficient in the computation. Defaults to ``True``.
25
25
  log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to ``False``.
@@ -51,7 +51,7 @@ def kinematical_approximation(
51
51
 
52
52
  substrate_sld = sld[:, -1:]
53
53
 
54
- rf = _get_resnel_reflectivity(q, substrate_sld[:, None])
54
+ rf = _get_fresnel_reflectivity(q, substrate_sld[:, None])
55
55
 
56
56
  r = torch.clamp_max_(r * rf / substrate_sld.real ** 2, 1.)
57
57
 
@@ -61,12 +61,11 @@ def kinematical_approximation(
61
61
  return r
62
62
 
63
63
 
64
- def _get_resnel_reflectivity(q, substrate_slds):
65
- _RE_CONST = 0.28174103675406496
64
+ def _get_fresnel_reflectivity(q, substrate_slds):
65
+ _RE_CONST = 0.28174103675406496 # 2/sqrt(16*pi)
66
66
 
67
67
  q_c = torch.sqrt(substrate_slds + 0j) / _RE_CONST * 2
68
68
  q_prime = torch.sqrt(q ** 2 - q_c ** 2 + 0j)
69
69
  r_f = ((q - q_prime) / (q + q_prime)).abs().float() ** 2
70
70
 
71
- return r_f.squeeze(-1)
72
-
71
+ return r_f.squeeze(-1)
@@ -19,7 +19,7 @@ def abeles_memory_eff(
19
19
  q (Tensor): tensor of momentum transfer (q) values with shape [batch_size, n_points] or [n_points]
20
20
  thickness (Tensor): tensor containing the layer thicknesses (ordered from top to bottom) with shape [batch_size, n_layers]
21
21
  roughness (Tensor): tensor containing the interlayer roughnesses (ordered from top to bottom) with shape [batch_size, n_layers + 1]
22
- sld (Tensor): tensors containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
22
+ sld (Tensor): tensor containing the layer SLDs (real or complex; ordered from top to bottom) with shape [batch_size, n_layers + 1].
23
23
  It includes the substrate but excludes the ambient medium which is assumed to have an SLD of 0.
24
24
 
25
25
  Returns:
@@ -14,26 +14,41 @@ def abeles_constant_smearing(
14
14
  roughness: Tensor,
15
15
  sld: Tensor,
16
16
  dq: Tensor = None,
17
- gauss_num: int = 51,
18
- constant_dq: bool = True,
17
+ gauss_num: int = 31,
18
+ constant_dq: bool = False,
19
19
  abeles_func=None,
20
+ **abeles_kwargs
20
21
  ):
21
22
  abeles_func = abeles_func or abeles
23
+
24
+ if dq.dtype != thickness.dtype:
25
+ q = q.to(thickness)
26
+
27
+ if dq.dtype != thickness.dtype:
28
+ dq = dq.to(thickness)
29
+
30
+ if q.shape[0] == 1:
31
+ q = q.repeat(thickness.shape[0], 1)
32
+
22
33
  q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
23
34
  kernels = _get_t_gauss_kernels(dq, gauss_num)
24
-
25
- curves = abeles_func(q_lin, thickness, roughness, sld)
35
+
36
+ curves = abeles_func(q_lin, thickness, roughness, sld, **abeles_kwargs)
26
37
 
27
38
  padding = (kernels.shape[-1] - 1) // 2
39
+ padded_curves = pad(curves, (padding, padding), 'reflect')
40
+
28
41
  smeared_curves = conv1d(
29
- pad(curves[None], (padding, padding), 'reflect'), kernels[:, None], groups=kernels.shape[0],
30
- )[0]
42
+ padded_curves, kernels[:, None], groups=kernels.shape[0],
43
+ )
31
44
 
32
45
  if q.shape[0] != smeared_curves.shape[0]:
33
- q = q.expand(smeared_curves.shape[0], *q.shape[1:])
34
-
46
+ repeat_factor = smeared_curves.shape[0] // q.shape[0]
47
+ q = q.repeat(repeat_factor, 1)
48
+ q_lin = q_lin.repeat(repeat_factor, 1)
49
+
35
50
  smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)
36
-
51
+
37
52
  return smeared_curves
38
53
 
39
54
 
@@ -55,7 +70,7 @@ def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
55
70
  return gauss_y
56
71
 
57
72
 
58
- def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = True):
73
+ def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = False):
59
74
  if constant_dq:
60
75
  return _get_q_axes_for_constant_dq(q, resolutions, gaussnum)
61
76
  else:
@@ -0,0 +1,110 @@
1
+ import torch
2
+ import scipy
3
+ import numpy as np
4
+ from functools import lru_cache
5
+ from typing import Tuple
6
+
7
+ from reflectorch.data_generation.reflectivity.abeles import abeles
8
+
9
+ #Pytorch version based on the JAX implementation of pointwise smearing in the refnx package.
10
+
11
+ @lru_cache(maxsize=128)
12
+ def gauss_legendre(n: int) -> Tuple[torch.Tensor, torch.Tensor]:
13
+ """
14
+ Calculate Gaussian quadrature abscissae and weights.
15
+
16
+ Args:
17
+ n (int): Gaussian quadrature order.
18
+
19
+ Returns:
20
+ Tuple[torch.Tensor, torch.Tensor]: The abscissae and weights for Gauss-Legendre integration.
21
+ """
22
+ return scipy.special.p_roots(n)
23
+
24
+ def gauss(x: torch.Tensor) -> torch.Tensor:
25
+ """
26
+ Calculate the Gaussian function.
27
+
28
+ Args:
29
+ x (torch.Tensor): Input tensor.
30
+
31
+ Returns:
32
+ torch.Tensor: Output tensor after applying the Gaussian function.
33
+ """
34
+ return torch.exp(-0.5 * x * x)
35
+
36
+ def abeles_pointwise_smearing(
37
+ q: torch.Tensor,
38
+ dq: torch.Tensor,
39
+ thickness: torch.Tensor,
40
+ roughness: torch.Tensor,
41
+ sld: torch.Tensor,
42
+ gauss_num: int = 17,
43
+ abeles_func=None,
44
+ **abeles_kwargs,
45
+ ) -> torch.Tensor:
46
+ """
47
+ Compute reflectivity with variable smearing using Gaussian quadrature.
48
+
49
+ Args:
50
+ q (torch.Tensor): The momentum transfer (q) values.
51
+ dq (torch.Tensor): The resolution for curve smearing.
52
+ thickness (torch.Tensor): The layer thicknesses.
53
+ roughness (torch.Tensor): The interlayer roughnesses.
54
+ sld (torch.Tensor): The SLDs of the layers.
55
+ sld_magnetic (torch.Tensor, optional): The magnetic SLDs of the layers.
56
+ magnetization_angle (torch.Tensor, optional): The magnetization angles.
57
+ polarizer_eff (torch.Tensor, optional): The polarizer efficiency.
58
+ analyzer_eff (torch.Tensor, optional): The analyzer efficiency.
59
+ abeles_func (Callable, optional): A function implementing the simulation of the reflectivity curves.
60
+ gauss_num (int, optional): Gaussian quadrature order. Defaults to 17.
61
+
62
+ Returns:
63
+ torch.Tensor: The computed reflectivity curves.
64
+ """
65
+ abeles_func = abeles_func or abeles
66
+
67
+ if q.shape[0] == 1:
68
+ q = q.repeat(thickness.shape[0], 1)
69
+
70
+ _FWHM = 2 * np.sqrt(2 * np.log(2.0))
71
+ _INTLIMIT = 3.5
72
+
73
+ bs = q.shape[0]
74
+ nq = q.shape[-1]
75
+ device = q.device
76
+
77
+ quad_order = gauss_num
78
+ abscissa, weights = gauss_legendre(quad_order)
79
+ abscissa = torch.tensor(abscissa)[None, :, None].to(device)
80
+ weights = torch.tensor(weights)[None, :, None].to(device)
81
+ prefactor = 1.0 / np.sqrt(2 * np.pi)
82
+
83
+ gaussvals = prefactor * gauss(abscissa * _INTLIMIT)
84
+
85
+ va = q[:, None, :] - _INTLIMIT * dq[:, None, :] / _FWHM
86
+ vb = q[:, None, :] + _INTLIMIT * dq[:, None, :] / _FWHM
87
+
88
+ qvals_for_res_0 = (abscissa * (vb - va) + vb + va) / 2
89
+ qvals_for_res = qvals_for_res_0.reshape(bs, -1)
90
+
91
+ refl_curves = abeles_func(
92
+ q=qvals_for_res,
93
+ thickness=thickness,
94
+ roughness=roughness,
95
+ sld=sld,
96
+ **abeles_kwargs
97
+ )
98
+
99
+ # Handle multiple channels
100
+ if refl_curves.dim() == 3:
101
+ n_channels = refl_curves.shape[1]
102
+ refl_curves = refl_curves.reshape(bs, n_channels, quad_order, nq)
103
+ refl_curves = refl_curves * gaussvals.unsqueeze(1) * weights.unsqueeze(1)
104
+ refl_curves = torch.sum(refl_curves, dim=2) * _INTLIMIT
105
+ else:
106
+ refl_curves = refl_curves.reshape(bs, quad_order, nq)
107
+ refl_curves = refl_curves * gaussvals * weights
108
+ refl_curves = torch.sum(refl_curves, dim=1) * _INTLIMIT
109
+
110
+ return refl_curves
@@ -9,15 +9,15 @@ class Smearing(object):
9
9
  The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.
10
10
 
11
11
  Args:
12
- sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (1e-4, 5e-3).
12
+ sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (0.01, 0.1).
13
13
  constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
14
14
  otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to True.
15
15
  gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
16
16
  share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
17
17
  """
18
18
  def __init__(self,
19
- sigma_range: tuple = (1e-4, 5e-3),
20
- constant_dq: bool = True,
19
+ sigma_range: tuple = (0.01, 0.1),
20
+ constant_dq: bool = False,
21
21
  gauss_num: int = 31,
22
22
  share_smeared: float = 0.2,
23
23
  ):
@@ -38,31 +38,62 @@ class Smearing(object):
38
38
  indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
39
39
  indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
40
40
  return dq, indices
41
+
42
+ def scale_resolutions(self, resolutions: Tensor) -> Tensor:
43
+ """Scales the q-resolution values to [-1,1] range using the internal sigma range"""
44
+ sigma_min = 0.0 if self.share_smeared != 1.0 else self.sigma_min
45
+ return 2 * (resolutions - sigma_min) / (self.sigma_max - sigma_min) - 1
46
+
47
+ def get_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs:dict = None):
48
+ refl_kwargs = refl_kwargs or {}
41
49
 
42
- def get_curves(self, q_values: Tensor, params: BasicParams):
43
50
  dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)
51
+ q_resolutions = torch.zeros(q_values.shape[0], 1, dtype=q_values.dtype, device=q_values.device)
44
52
 
45
53
  if dq is None:
46
- return params.reflectivity(q_values, log=False)
47
-
48
- curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)
54
+ return params.reflectivity(q_values, **refl_kwargs), q_resolutions
55
+
56
+ refl_kwargs_not_smeared = {}
57
+ refl_kwargs_smeared = {}
58
+ for key, value in refl_kwargs.items():
59
+ if isinstance(value, torch.Tensor) and value.shape[0] == params.batch_size:
60
+ refl_kwargs_not_smeared[key] = value[~indices]
61
+ refl_kwargs_smeared[key] = value[indices]
62
+ else:
63
+ refl_kwargs_not_smeared[key] = value
64
+ refl_kwargs_smeared[key] = value
49
65
 
66
+ # Compute unsmeared reflectivity
50
67
  if (~indices).sum().item():
51
68
  if q_values.dim() == 2 and q_values.shape[0] > 1:
52
69
  q = q_values[~indices]
53
70
  else:
54
71
  q = q_values
55
72
 
56
- curves[~indices] = params[~indices].reflectivity(q, log=False)
73
+ reflectivity_not_smeared = params[~indices].reflectivity(q, **refl_kwargs_not_smeared)
74
+ else:
75
+ reflectivity_not_smeared = None
57
76
 
77
+ # Compute smeared reflectivity
58
78
  if indices.sum().item():
59
79
  if q_values.dim() == 2 and q_values.shape[0] > 1:
60
80
  q = q_values[indices]
61
81
  else:
62
82
  q = q_values
63
83
 
64
- curves[indices] = params[indices].reflectivity(
65
- q, dq=dq, constant_dq=self.constant_dq, log=False, gauss_num=self.gauss_num
84
+ reflectivity_smeared = params[indices].reflectivity(
85
+ q, dq=dq, constant_dq=self.constant_dq, gauss_num=self.gauss_num, **refl_kwargs_smeared
66
86
  )
87
+ else:
88
+ reflectivity_smeared = None
89
+
90
+ curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)
91
+
92
+ if (~indices).sum().item():
93
+ curves[~indices] = reflectivity_not_smeared
94
+
95
+ curves[indices] = reflectivity_smeared
96
+
97
+ q_resolutions[indices] = dq
67
98
 
68
- return curves
99
+ return curves, q_resolutions