reflectorch-1.5.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. reflectorch/__init__.py +17 -0
  2. reflectorch/data_generation/__init__.py +128 -0
  3. reflectorch/data_generation/dataset.py +216 -0
  4. reflectorch/data_generation/likelihoods.py +80 -0
  5. reflectorch/data_generation/noise.py +471 -0
  6. reflectorch/data_generation/priors/__init__.py +60 -0
  7. reflectorch/data_generation/priors/base.py +55 -0
  8. reflectorch/data_generation/priors/exp_subprior_sampler.py +298 -0
  9. reflectorch/data_generation/priors/independent_priors.py +195 -0
  10. reflectorch/data_generation/priors/multilayer_models.py +311 -0
  11. reflectorch/data_generation/priors/multilayer_structures.py +104 -0
  12. reflectorch/data_generation/priors/no_constraints.py +206 -0
  13. reflectorch/data_generation/priors/parametric_models.py +842 -0
  14. reflectorch/data_generation/priors/parametric_subpriors.py +369 -0
  15. reflectorch/data_generation/priors/params.py +252 -0
  16. reflectorch/data_generation/priors/sampler_strategies.py +370 -0
  17. reflectorch/data_generation/priors/scaler_mixin.py +65 -0
  18. reflectorch/data_generation/priors/subprior_sampler.py +371 -0
  19. reflectorch/data_generation/priors/utils.py +118 -0
  20. reflectorch/data_generation/process_data.py +41 -0
  21. reflectorch/data_generation/q_generator.py +280 -0
  22. reflectorch/data_generation/reflectivity/__init__.py +102 -0
  23. reflectorch/data_generation/reflectivity/abeles.py +97 -0
  24. reflectorch/data_generation/reflectivity/kinematical.py +71 -0
  25. reflectorch/data_generation/reflectivity/memory_eff.py +105 -0
  26. reflectorch/data_generation/reflectivity/numpy_implementations.py +120 -0
  27. reflectorch/data_generation/reflectivity/smearing.py +138 -0
  28. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  29. reflectorch/data_generation/scale_curves.py +112 -0
  30. reflectorch/data_generation/smearing.py +99 -0
  31. reflectorch/data_generation/utils.py +223 -0
  32. reflectorch/extensions/__init__.py +0 -0
  33. reflectorch/extensions/jupyter/__init__.py +11 -0
  34. reflectorch/extensions/jupyter/api.py +85 -0
  35. reflectorch/extensions/jupyter/callbacks.py +34 -0
  36. reflectorch/extensions/jupyter/components.py +758 -0
  37. reflectorch/extensions/jupyter/custom_select.py +268 -0
  38. reflectorch/extensions/jupyter/log_widget.py +241 -0
  39. reflectorch/extensions/jupyter/model_selection.py +495 -0
  40. reflectorch/extensions/jupyter/plotly_plot_manager.py +329 -0
  41. reflectorch/extensions/jupyter/widget.py +625 -0
  42. reflectorch/extensions/matplotlib/__init__.py +5 -0
  43. reflectorch/extensions/matplotlib/losses.py +32 -0
  44. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  45. reflectorch/inference/__init__.py +28 -0
  46. reflectorch/inference/inference_model.py +848 -0
  47. reflectorch/inference/input_interface.py +239 -0
  48. reflectorch/inference/loading_data.py +55 -0
  49. reflectorch/inference/multilayer_fitter.py +171 -0
  50. reflectorch/inference/multilayer_inference_model.py +193 -0
  51. reflectorch/inference/plotting.py +524 -0
  52. reflectorch/inference/preprocess_exp/__init__.py +7 -0
  53. reflectorch/inference/preprocess_exp/attenuation.py +36 -0
  54. reflectorch/inference/preprocess_exp/cut_with_q_ratio.py +31 -0
  55. reflectorch/inference/preprocess_exp/footprint.py +81 -0
  56. reflectorch/inference/preprocess_exp/interpolation.py +19 -0
  57. reflectorch/inference/preprocess_exp/normalize.py +21 -0
  58. reflectorch/inference/preprocess_exp/preprocess.py +121 -0
  59. reflectorch/inference/query_matcher.py +82 -0
  60. reflectorch/inference/record_time.py +43 -0
  61. reflectorch/inference/sampler_solution.py +56 -0
  62. reflectorch/inference/scipy_fitter.py +364 -0
  63. reflectorch/inference/torch_fitter.py +87 -0
  64. reflectorch/ml/__init__.py +32 -0
  65. reflectorch/ml/basic_trainer.py +292 -0
  66. reflectorch/ml/callbacks.py +81 -0
  67. reflectorch/ml/dataloaders.py +27 -0
  68. reflectorch/ml/loggers.py +56 -0
  69. reflectorch/ml/schedulers.py +356 -0
  70. reflectorch/ml/trainers.py +201 -0
  71. reflectorch/ml/utils.py +2 -0
  72. reflectorch/models/__init__.py +16 -0
  73. reflectorch/models/activations.py +50 -0
  74. reflectorch/models/encoders/__init__.py +19 -0
  75. reflectorch/models/encoders/conv_encoder.py +219 -0
  76. reflectorch/models/encoders/conv_res_net.py +115 -0
  77. reflectorch/models/encoders/fno.py +134 -0
  78. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  79. reflectorch/models/networks/__init__.py +14 -0
  80. reflectorch/models/networks/mlp_networks.py +434 -0
  81. reflectorch/models/networks/residual_net.py +157 -0
  82. reflectorch/paths.py +29 -0
  83. reflectorch/runs/__init__.py +31 -0
  84. reflectorch/runs/config.py +25 -0
  85. reflectorch/runs/slurm_utils.py +93 -0
  86. reflectorch/runs/train.py +78 -0
  87. reflectorch/runs/utils.py +405 -0
  88. reflectorch/test_config.py +4 -0
  89. reflectorch/train.py +4 -0
  90. reflectorch/train_on_cluster.py +4 -0
  91. reflectorch/utils.py +98 -0
  92. reflectorch-1.5.1.dist-info/METADATA +151 -0
  93. reflectorch-1.5.1.dist-info/RECORD +96 -0
  94. reflectorch-1.5.1.dist-info/WHEEL +5 -0
  95. reflectorch-1.5.1.dist-info/licenses/LICENSE.txt +21 -0
  96. reflectorch-1.5.1.dist-info/top_level.txt +1 -0
reflectorch/data_generation/reflectivity/numpy_implementations.py
@@ -0,0 +1,120 @@
+ # -*- coding: utf-8 -*-
+
+ import numpy as np
+
+
+ def abeles_np(
+     q: np.ndarray,
+     thickness: np.ndarray,
+     roughness: np.ndarray,
+     sld: np.ndarray,
+ ):
+     c_dtype = np.complex128 if q.dtype == np.float64 else np.complex64
+
+     # remember whether the input was unbatched (1d) so the batch dimension can be squeezed out again
+     if q.ndim == thickness.ndim == roughness.ndim == sld.ndim == 1:
+         zero_batch = True
+     else:
+         zero_batch = False
+
+     thickness = np.atleast_2d(thickness)
+     roughness = np.atleast_2d(roughness)
+     sld = np.atleast_2d(sld)
+
+     batch_size, num_layers = thickness.shape
+
+     # prepend the ambient medium (zero SLD, zero thickness) and square the roughnesses
+     sld = np.concatenate([np.zeros((batch_size, 1)).astype(sld.dtype), sld], -1)[:, None]
+     thickness = np.concatenate([np.zeros((batch_size, 1)).astype(thickness.dtype), thickness], -1)[:, None]
+     roughness = roughness[:, None] ** 2
+
+     sld = sld * 1e-6 + 1e-30j
+
+     k_z0 = (q / 2).astype(c_dtype)
+
+     if len(k_z0.shape) == 1:
+         k_z0 = k_z0[None]
+
+     if len(k_z0.shape) == 2:
+         k_z0 = k_z0[..., None]
+
+     k_n = np.sqrt(k_z0 ** 2 - 4 * np.pi * sld)
+
+     # k_n.shape - (batch, q, layers)
+
+     k_n, k_np1 = k_n[..., :-1], k_n[..., 1:]
+
+     beta = 1j * thickness * k_n
+
+     exp_beta = np.exp(beta)
+     exp_m_beta = np.exp(-beta)
+
+     # interface reflection coefficients with Nevot-Croce roughness damping
+     rn = (k_n - k_np1) / (k_n + k_np1) * np.exp(- 2 * k_n * k_np1 * roughness)
+
+     c_matrices = np.stack([
+         np.stack([exp_beta, rn * exp_m_beta], -1),
+         np.stack([rn * exp_beta, exp_m_beta], -1),
+     ], -1)
+
+     c_matrices = np.moveaxis(c_matrices, -3, 0)
+
+     m, c_matrices = c_matrices[0], c_matrices[1:]
+
+     # accumulate the characteristic matrices layer by layer
+     for c in c_matrices:
+         m = m @ c
+
+     r = np.abs(m[..., 1, 0] / m[..., 0, 0]) ** 2
+     r = np.clip(r, None, 1.)
+
+     if zero_batch:
+         r = r[0]
+
+     return r
+
+
+ def kinematical_approximation_np(
+     q: np.ndarray,
+     thickness: np.ndarray,
+     roughness: np.ndarray,
+     sld: np.ndarray,
+ ):
+     if q.ndim == thickness.ndim == roughness.ndim == sld.ndim == 1:
+         zero_batch = True
+     else:
+         zero_batch = False
+
+     thickness = np.atleast_2d(thickness)
+     roughness = np.atleast_2d(roughness)
+     sld = np.atleast_2d(sld) * 1e-6 + 1e-30j
+     substrate_sld = sld[:, -1:]
+
+     batch_size, num_layers = thickness.shape
+
+     if q.ndim == 1:
+         q = q[None]
+
+     if q.ndim == 2:
+         q = q[..., None]
+
+     drho = np.concatenate([sld[..., 0][..., None], sld[..., 1:] - sld[..., :-1]], -1)[:, None]
+     thickness = np.cumsum(np.concatenate([np.zeros((batch_size, 1)), thickness], -1), -1)[:, None]
+     roughness = roughness[:, None]
+
+     r = np.abs((drho * np.exp(- (roughness * q) ** 2 / 2 + 1j * (q * thickness))).sum(-1)).astype(float) ** 2
+
+     rf = _get_resnel_reflectivity_np(q, substrate_sld[:, None])
+
+     r = np.clip(r * rf / np.real(substrate_sld) ** 2, None, 1.)
+
+     if zero_batch:
+         r = r[0]
+
+     return r
+
+
+ def _get_resnel_reflectivity_np(q, substrate_slds):
+     _RE_CONST = 0.28174103675406496
+
+     q_c = np.sqrt(substrate_slds + 0j) / _RE_CONST * 2
+     q_prime = np.sqrt(q ** 2 - q_c ** 2 + 0j)
+     r_f = np.abs((q - q_prime) / (q + q_prime)).astype(float) ** 2
+
+     return r_f[..., 0]
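Note: a minimal usage sketch for abeles_np (the hunk above matches reflectorch/data_generation/reflectivity/numpy_implementations.py in the file list). The layer values below are hypothetical and only illustrate the expected shapes: thickness has one entry per finite layer, while roughness and sld carry one extra entry for the substrate.

    import numpy as np
    from reflectorch.data_generation.reflectivity.numpy_implementations import abeles_np

    q = np.linspace(0.01, 0.3, 256)         # momentum transfer in 1/A
    thickness = np.array([80.0, 120.0])     # two finite layers, in A
    roughness = np.array([3.0, 5.0, 4.0])   # one value per interface (including the substrate), in A
    sld = np.array([4.5, 2.0, 20.1])        # layer SLDs plus the substrate SLD, in 1e-6 / A^2

    r = abeles_np(q, thickness, roughness, sld)   # 1d inputs -> 1d reflectivity of shape (256,)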
reflectorch/data_generation/reflectivity/smearing.py
@@ -0,0 +1,138 @@
+ # -*- coding: utf-8 -*-
+ from math import pi, sqrt, log
+
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.reflectivity.abeles import abeles
+ from torch.nn.functional import conv1d, pad
+
+
+ def abeles_constant_smearing(
+     q: Tensor,
+     thickness: Tensor,
+     roughness: Tensor,
+     sld: Tensor,
+     dq: Tensor = None,
+     gauss_num: int = 31,
+     constant_dq: bool = False,
+     abeles_func=None,
+     **abeles_kwargs
+ ):
+     abeles_func = abeles_func or abeles
+
+     if q.dtype != thickness.dtype:
+         q = q.to(thickness)
+
+     if dq.dtype != thickness.dtype:
+         dq = dq.to(thickness)
+
+     if q.shape[0] == 1:
+         q = q.repeat(thickness.shape[0], 1)
+
+     # oversampled q axis and per-curve Gaussian kernels for the smearing convolution
+     q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
+     kernels = _get_t_gauss_kernels(dq, gauss_num)
+
+     curves = abeles_func(q_lin, thickness, roughness, sld, **abeles_kwargs)
+
+     padding = (kernels.shape[-1] - 1) // 2
+     padded_curves = pad(curves, (padding, padding), 'reflect')
+
+     smeared_curves = conv1d(
+         padded_curves, kernels[:, None], groups=kernels.shape[0],
+     )
+
+     if q.shape[0] != smeared_curves.shape[0]:
+         repeat_factor = smeared_curves.shape[0] // q.shape[0]
+         q = q.repeat(repeat_factor, 1)
+         q_lin = q_lin.repeat(repeat_factor, 1)
+
+     # interpolate the smeared curves back onto the original q values
+     smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)
+
+     return smeared_curves
+
+
+ _FWHM = 2 * sqrt(2 * log(2.0))
+ _2PI_SQRT = 1. / sqrt(2 * pi)
+
+
+ def _batch_linspace(start: Tensor, end: Tensor, num: int):
+     return torch.linspace(0, 1, int(num), device=end.device, dtype=end.dtype)[None] * (end - start) + start
+
+
+ def _torch_gauss(x, s):
+     return _2PI_SQRT / s * torch.exp(-0.5 * x ** 2 / s / s)
+
+
+ def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
+     gauss_x = _batch_linspace(-1.7 * resolutions, 1.7 * resolutions, gaussnum)
+     gauss_y = _torch_gauss(gauss_x, resolutions / _FWHM) * (gauss_x[:, 1] - gauss_x[:, 0])[:, None]
+     return gauss_y
+
+
+ def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = False):
+     if constant_dq:
+         return _get_q_axes_for_constant_dq(q, resolutions, gaussnum)
+     else:
+         return _get_q_axes_for_linear_dq(q, resolutions, gaussnum)
+
+
+ def _get_q_axes_for_linear_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51):
+     gaussgpoint = (gaussnum - 1) / 2
+
+     lowq = torch.clamp_min_(q.min(1).values, 1e-6)
+     highq = q.max(1).values
+
+     start = torch.log10(lowq)[:, None] - 6 * resolutions / _FWHM
+     end = torch.log10(highq[:, None] * (1 + 6 * resolutions / _FWHM))
+
+     interpnums = torch.abs(
+         (torch.abs(end - start)) / (1.7 * resolutions / _FWHM / gaussgpoint)
+     ).round().to(int)
+
+     q_lin = 10 ** _batch_linspace_with_padding(start, end, interpnums)
+
+     return q_lin
+
+
+ def _get_q_axes_for_constant_dq(q: Tensor, resolutions: Tensor, gaussnum: int = 51) -> Tensor:
+     gaussgpoint = (gaussnum - 1) / 2
+
+     start = q.min(1).values[:, None] - resolutions * 1.7
+     end = q.max(1).values[:, None] + resolutions * 1.7
+
+     interpnums = torch.abs(
+         (torch.abs(end - start)) / (1.7 * resolutions / gaussgpoint)
+     ).round().to(int)
+
+     q_lin = _batch_linspace_with_padding(start, end, interpnums)
+     q_lin = torch.clamp_min_(q_lin, 1e-6)
+
+     return q_lin
+
+
+ def _batch_linspace_with_padding(start: Tensor, end: Tensor, nums: Tensor) -> Tensor:
+     max_num = nums.max().int().item()
+
+     deltas = 1 / (nums - 1)
+
+     x = torch.clamp_min_(_batch_linspace(deltas * (nums - max_num), torch.ones_like(deltas), max_num), 0)
+
+     x = x * (end - start) + start
+
+     return x
+
+
+ def _batch_linear_interp1d(x: Tensor, y: Tensor, x_new: Tensor) -> Tensor:
+     eps = torch.finfo(y.dtype).eps
+
+     ind = torch.searchsorted(x.contiguous(), x_new.contiguous())
+
+     ind = torch.clamp_(ind - 1, 0, x.shape[-1] - 2)
+     slopes = (y[..., 1:] - y[..., :-1]) / (eps + (x[..., 1:] - x[..., :-1]))
+     ind_y = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * y.shape[1]
+     ind_slopes = ind + torch.arange(slopes.shape[0], device=slopes.device)[:, None] * slopes.shape[1]
+
+     y_new = y.flatten()[ind_y] + slopes.flatten()[ind_slopes] * (x_new - x.flatten()[ind_y])
+
+     return y_new
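Note: a minimal usage sketch for abeles_constant_smearing with hypothetical values. With constant_dq=False (the default), dq is interpreted as a relative resolution dq/q, and a q tensor with a leading dimension of 1 is broadcast over the batch.

    import torch
    from reflectorch.data_generation.reflectivity.smearing import abeles_constant_smearing

    q = torch.linspace(0.02, 0.3, 128)[None]      # shape (1, 128), broadcast to the batch size
    thickness = torch.tensor([[80.0, 120.0]])     # (batch, layers)
    roughness = torch.tensor([[3.0, 5.0, 4.0]])   # (batch, layers + 1)
    sld = torch.tensor([[4.5, 2.0, 20.1]])        # (batch, layers + 1)
    dq = torch.tensor([[0.05]])                   # (batch, 1) relative resolution dq/q

    curves = abeles_constant_smearing(q, thickness, roughness, sld, dq=dq, gauss_num=31)
    # curves has shape (1, 128): the smeared reflectivity interpolated back onto q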
reflectorch/data_generation/reflectivity/smearing_pointwise.py
@@ -0,0 +1,110 @@
+ import torch
+ import scipy.special
+ import numpy as np
+ from functools import lru_cache
+ from typing import Tuple
+
+ from reflectorch.data_generation.reflectivity.abeles import abeles
+
+ # PyTorch version based on the JAX implementation of pointwise smearing in the refnx package.
+
+ @lru_cache(maxsize=128)
+ def gauss_legendre(n: int) -> Tuple[np.ndarray, np.ndarray]:
+     """
+     Calculate Gaussian quadrature abscissae and weights.
+
+     Args:
+         n (int): Gaussian quadrature order.
+
+     Returns:
+         Tuple[np.ndarray, np.ndarray]: The abscissae and weights for Gauss-Legendre integration.
+     """
+     return scipy.special.p_roots(n)
+
+ def gauss(x: torch.Tensor) -> torch.Tensor:
+     """
+     Evaluate the unnormalized Gaussian function exp(-0.5 * x**2).
+
+     Args:
+         x (torch.Tensor): Input tensor.
+
+     Returns:
+         torch.Tensor: Output tensor after applying the Gaussian function.
+     """
+     return torch.exp(-0.5 * x * x)
+
+ def abeles_pointwise_smearing(
+     q: torch.Tensor,
+     dq: torch.Tensor,
+     thickness: torch.Tensor,
+     roughness: torch.Tensor,
+     sld: torch.Tensor,
+     gauss_num: int = 17,
+     abeles_func=None,
+     **abeles_kwargs,
+ ) -> torch.Tensor:
+     """
+     Compute reflectivity with variable (pointwise) smearing using Gaussian quadrature.
+
+     Args:
+         q (torch.Tensor): The momentum transfer (q) values.
+         dq (torch.Tensor): The per-point resolution for curve smearing.
+         thickness (torch.Tensor): The layer thicknesses.
+         roughness (torch.Tensor): The interlayer roughnesses.
+         sld (torch.Tensor): The SLDs of the layers.
+         gauss_num (int, optional): Gaussian quadrature order. Defaults to 17.
+         abeles_func (Callable, optional): A function implementing the simulation of the reflectivity curves.
+         **abeles_kwargs: Additional arguments (e.g. magnetic SLDs, magnetization angles, polarizer and
+             analyzer efficiencies) forwarded to ``abeles_func`` when it supports them.
+
+     Returns:
+         torch.Tensor: The computed reflectivity curves.
+     """
+     abeles_func = abeles_func or abeles
+
+     if q.shape[0] == 1:
+         q = q.repeat(thickness.shape[0], 1)
+
+     _FWHM = 2 * np.sqrt(2 * np.log(2.0))
+     _INTLIMIT = 3.5
+
+     bs = q.shape[0]
+     nq = q.shape[-1]
+     device = q.device
+
+     quad_order = gauss_num
+     abscissa, weights = gauss_legendre(quad_order)
+     abscissa = torch.tensor(abscissa)[None, :, None].to(device)
+     weights = torch.tensor(weights)[None, :, None].to(device)
+     prefactor = 1.0 / np.sqrt(2 * np.pi)
+
+     gaussvals = prefactor * gauss(abscissa * _INTLIMIT)
+
+     # integration interval [va, vb] around each q point, +/- _INTLIMIT standard deviations
+     va = q[:, None, :] - _INTLIMIT * dq[:, None, :] / _FWHM
+     vb = q[:, None, :] + _INTLIMIT * dq[:, None, :] / _FWHM
+
+     qvals_for_res_0 = (abscissa * (vb - va) + vb + va) / 2
+     qvals_for_res = qvals_for_res_0.reshape(bs, -1)
+
+     refl_curves = abeles_func(
+         q=qvals_for_res,
+         thickness=thickness,
+         roughness=roughness,
+         sld=sld,
+         **abeles_kwargs
+     )
+
+     # Handle multiple channels
+     if refl_curves.dim() == 3:
+         n_channels = refl_curves.shape[1]
+         refl_curves = refl_curves.reshape(bs, n_channels, quad_order, nq)
+         refl_curves = refl_curves * gaussvals.unsqueeze(1) * weights.unsqueeze(1)
+         refl_curves = torch.sum(refl_curves, dim=2) * _INTLIMIT
+     else:
+         refl_curves = refl_curves.reshape(bs, quad_order, nq)
+         refl_curves = refl_curves * gaussvals * weights
+         refl_curves = torch.sum(refl_curves, dim=1) * _INTLIMIT
+
+     return refl_curves
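Note: a minimal usage sketch for abeles_pointwise_smearing with hypothetical values. The tensors are kept in float64 so they match the double-precision Gauss-Legendre nodes returned by scipy; dq is divided by _FWHM inside the function, i.e. it is treated as a per-point FWHM resolution.

    import torch
    from reflectorch.data_generation.reflectivity.smearing_pointwise import abeles_pointwise_smearing

    q = torch.linspace(0.02, 0.3, 128, dtype=torch.float64)[None]   # (1, 128)
    dq = 0.02 * q                                                   # per-point resolution, same shape as q
    thickness = torch.tensor([[80.0, 120.0]], dtype=torch.float64)
    roughness = torch.tensor([[3.0, 5.0, 4.0]], dtype=torch.float64)
    sld = torch.tensor([[4.5, 2.0, 20.1]], dtype=torch.float64)

    curves = abeles_pointwise_smearing(q, dq, thickness, roughness, sld, gauss_num=17)
    # curves has shape (1, 128)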
reflectorch/data_generation/scale_curves.py
@@ -0,0 +1,112 @@
+ from pathlib import Path
+
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.priors import PriorSampler
+ from reflectorch.paths import SAVED_MODELS_DIR
+
+
+ class CurvesScaler(object):
+     """Base class for curve scalers"""
+     def scale(self, curves: Tensor):
+         raise NotImplementedError
+
+     def restore(self, curves: Tensor):
+         raise NotImplementedError
+
+
+ class LogAffineCurvesScaler(CurvesScaler):
+     r"""Curve scaler which scales the reflectivity curves according to the logarithmic affine transformation
+     :math:`\log_{10}(R + eps) \cdot weight + bias`.
+
+     Args:
+         weight (float): multiplication factor in the transformation
+         bias (float): addition term in the transformation
+         eps (float): small constant added to the curves before taking the logarithm; it sets the minimum
+             reflectivity value that is considered
+     """
+     def __init__(self, weight: float = 0.1, bias: float = 0.5, eps: float = 1e-10):
+         self.weight = weight
+         self.bias = bias
+         self.eps = eps
+
+     def scale(self, curves: Tensor):
+         """scales the reflectivity curves to an ML-friendly range
+
+         Args:
+             curves (Tensor): original reflectivity curves
+
+         Returns:
+             Tensor: reflectivity curves scaled to an ML-friendly range
+         """
+         return torch.log10(curves + self.eps) * self.weight + self.bias
+
+     def restore(self, curves: Tensor):
+         """restores the physical reflectivity curves
+
+         Args:
+             curves (Tensor): scaled reflectivity curves
+
+         Returns:
+             Tensor: reflectivity curves restored to the physical range
+         """
+         return 10 ** ((curves - self.bias) / self.weight) - self.eps
+
+
+ class MeanNormalizationCurvesScaler(CurvesScaler):
+     """Curve scaler which scales the reflectivity curves by the precomputed mean of a batch of curves
+
+     Args:
+         path (str, optional): path to the precomputed mean of the curves (resolved relative to ``SAVED_MODELS_DIR``),
+             only used if ``curves_mean`` is None. Defaults to None.
+         curves_mean (Tensor, optional): the precomputed mean of the curves. Defaults to None.
+         device (torch.device, optional): the Pytorch device. Defaults to 'cuda'.
+     """
+
+     def __init__(self, path: str = None, curves_mean: Tensor = None, device: torch.device = 'cuda'):
+         if curves_mean is None:
+             curves_mean = torch.load(self.get_path(path))
+         self.curves_mean = curves_mean.to(device)
+
+     def scale(self, curves: Tensor):
+         """scales the reflectivity curves to an ML-friendly range
+
+         Args:
+             curves (Tensor): original reflectivity curves
+
+         Returns:
+             Tensor: reflectivity curves scaled to an ML-friendly range
+         """
+         self.curves_mean = self.curves_mean.to(curves)
+         return curves / self.curves_mean - 1
+
+     def restore(self, curves: Tensor):
+         """restores the physical reflectivity curves
+
+         Args:
+             curves (Tensor): scaled reflectivity curves
+
+         Returns:
+             Tensor: reflectivity curves restored to the physical range
+         """
+         self.curves_mean = self.curves_mean.to(curves)
+         return (curves + 1) * self.curves_mean
+
+     @staticmethod
+     def save(prior_sampler: PriorSampler, q: Tensor, path: str, num: int = 16384):
+         """computes the mean of a batch of reflectivity curves and saves it
+
+         Args:
+             prior_sampler (PriorSampler): the prior sampler
+             q (Tensor): the q values
+             path (str): the path for saving the mean of the curves
+             num (int, optional): the number of curves used to compute the mean. Defaults to 16384.
+         """
+         params = prior_sampler.sample(num)
+         curves_mean = params.reflectivity(q, log=False).mean(0).cpu()
+         torch.save(curves_mean, MeanNormalizationCurvesScaler.get_path(path))
+
+     @staticmethod
+     def get_path(path: str) -> Path:
+         if not path.endswith('.pt'):
+             path = path + '.pt'
+         return SAVED_MODELS_DIR / path
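Note: a minimal round-trip sketch for LogAffineCurvesScaler with the default weight, bias and eps; the random curves are only a stand-in for simulated reflectivity.

    import torch
    from reflectorch.data_generation.scale_curves import LogAffineCurvesScaler

    scaler = LogAffineCurvesScaler(weight=0.1, bias=0.5, eps=1e-10)
    curves = torch.rand(4, 128) * 1e-2          # stand-in batch of reflectivity curves
    scaled = scaler.scale(curves)               # log10(R + eps) * 0.1 + 0.5
    restored = scaler.restore(scaled)           # inverse transform
    assert torch.allclose(restored, curves, atol=1e-6)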
reflectorch/data_generation/smearing.py
@@ -0,0 +1,99 @@
+ import torch
+ from torch import Tensor
+
+ from reflectorch.data_generation.priors.parametric_subpriors import BasicParams
+
+
+ class Smearing(object):
+     """Class which applies resolution smearing to the reflectivity curves.
+     The intensity at a q point is the average of the intensities of neighbouring q points, weighted by a Gaussian profile.
+
+     Args:
+         sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (0.01, 0.1).
+         constant_dq (bool, optional): if ``True`` the smearing is constant (a fixed dq at each point of the curve),
+             otherwise the smearing is linear (a fixed relative resolution dq/q at each point of the curve). Defaults to False.
+         gauss_num (int, optional): the number of points of the Gaussian smearing kernel. Defaults to 31.
+         share_smeared (float, optional): the fraction of curves in the batch to which resolution smearing is applied. Defaults to 0.2.
+     """
+     def __init__(self,
+                  sigma_range: tuple = (0.01, 0.1),
+                  constant_dq: bool = False,
+                  gauss_num: int = 31,
+                  share_smeared: float = 0.2,
+                  ):
+         self.sigma_min, self.sigma_max = sigma_range
+         self.sigma_delta = self.sigma_max - self.sigma_min
+         self.constant_dq = constant_dq
+         self.gauss_num = gauss_num
+         self.share_smeared = share_smeared
+
+     def __repr__(self):
+         return f'Smearing(({self.sigma_min}, {self.sigma_max}))'
+
+     def generate_resolutions(self, batch_size: int, device=None, dtype=None):
+         num_smeared = int(batch_size * self.share_smeared)
+         if not num_smeared:
+             return None, None
+         dq = torch.rand(num_smeared, 1, device=device, dtype=dtype) * self.sigma_delta + self.sigma_min
+         indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
+         indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
+         return dq, indices
+
+     def scale_resolutions(self, resolutions: Tensor) -> Tensor:
+         """Scales the q-resolution values to the [-1, 1] range using the internal sigma range"""
+         sigma_min = 0.0 if self.share_smeared != 1.0 else self.sigma_min
+         return 2 * (resolutions - sigma_min) / (self.sigma_max - sigma_min) - 1
+
+     def get_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs: dict = None):
+         refl_kwargs = refl_kwargs or {}
+
+         dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)
+         q_resolutions = torch.zeros(params.batch_size, 1, dtype=q_values.dtype, device=q_values.device)
+
+         if dq is None:
+             return params.reflectivity(q_values, **refl_kwargs), q_resolutions
+
+         # split per-curve keyword arguments between the smeared and unsmeared subsets of the batch
+         refl_kwargs_not_smeared = {}
+         refl_kwargs_smeared = {}
+         for key, value in refl_kwargs.items():
+             if isinstance(value, torch.Tensor) and value.shape[0] == params.batch_size:
+                 refl_kwargs_not_smeared[key] = value[~indices]
+                 refl_kwargs_smeared[key] = value[indices]
+             else:
+                 refl_kwargs_not_smeared[key] = value
+                 refl_kwargs_smeared[key] = value
+
+         # Compute unsmeared reflectivity
+         if (~indices).sum().item():
+             if q_values.dim() == 2 and q_values.shape[0] > 1:
+                 q = q_values[~indices]
+             else:
+                 q = q_values
+
+             reflectivity_not_smeared = params[~indices].reflectivity(q, **refl_kwargs_not_smeared)
+         else:
+             reflectivity_not_smeared = None
+
+         # Compute smeared reflectivity
+         if indices.sum().item():
+             if q_values.dim() == 2 and q_values.shape[0] > 1:
+                 q = q_values[indices]
+             else:
+                 q = q_values
+
+             reflectivity_smeared = params[indices].reflectivity(
+                 q, dq=dq, constant_dq=self.constant_dq, gauss_num=self.gauss_num, **refl_kwargs_smeared
+             )
+         else:
+             reflectivity_smeared = None
+
+         curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)
+
+         if (~indices).sum().item():
+             curves[~indices] = reflectivity_not_smeared
+
+         curves[indices] = reflectivity_smeared
+
+         q_resolutions[indices] = dq
+
+         return curves, q_resolutions
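Note: a minimal sketch of how the Smearing helper is typically configured. generate_resolutions can be called directly; get_curves additionally needs a BasicParams batch from a configured prior sampler, which is only indicated in the comments below.

    import torch
    from reflectorch.data_generation.smearing import Smearing

    smearing = Smearing(sigma_range=(0.01, 0.05), constant_dq=False, gauss_num=31, share_smeared=0.5)

    dq, mask = smearing.generate_resolutions(batch_size=8)
    # dq: (4, 1) resolutions drawn uniformly from sigma_range; mask: (8,) bool marking the smeared curves

    # With a BasicParams batch, e.g. params = prior_sampler.sample(8), and a q grid q_values of shape (8, n_q):
    # curves, q_res = smearing.get_curves(q_values, params)
    # q_res has shape (8, 1) and holds dq for the smeared curves and 0 elsewhere.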