reflectorch 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry; it is provided for informational purposes only.

Note: this release has been flagged as potentially problematic.

Files changed (41)
  1. reflectorch/data_generation/__init__.py +4 -0
  2. reflectorch/data_generation/dataset.py +27 -7
  3. reflectorch/data_generation/noise.py +115 -9
  4. reflectorch/data_generation/priors/parametric_models.py +91 -16
  5. reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
  6. reflectorch/data_generation/priors/sampler_strategies.py +67 -3
  7. reflectorch/data_generation/q_generator.py +97 -43
  8. reflectorch/data_generation/reflectivity/__init__.py +53 -11
  9. reflectorch/data_generation/reflectivity/kinematical.py +4 -5
  10. reflectorch/data_generation/reflectivity/smearing.py +25 -10
  11. reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
  12. reflectorch/data_generation/smearing.py +42 -11
  13. reflectorch/data_generation/utils.py +93 -18
  14. reflectorch/extensions/refnx/refnx_conversion.py +77 -0
  15. reflectorch/inference/inference_model.py +795 -159
  16. reflectorch/inference/loading_data.py +37 -0
  17. reflectorch/inference/plotting.py +517 -0
  18. reflectorch/inference/preprocess_exp/interpolation.py +5 -2
  19. reflectorch/inference/scipy_fitter.py +98 -7
  20. reflectorch/ml/__init__.py +2 -0
  21. reflectorch/ml/basic_trainer.py +18 -6
  22. reflectorch/ml/callbacks.py +5 -4
  23. reflectorch/ml/loggers.py +25 -0
  24. reflectorch/ml/schedulers.py +116 -0
  25. reflectorch/ml/trainers.py +131 -23
  26. reflectorch/models/__init__.py +2 -1
  27. reflectorch/models/encoders/__init__.py +2 -2
  28. reflectorch/models/encoders/conv_encoder.py +54 -40
  29. reflectorch/models/encoders/fno.py +23 -16
  30. reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
  31. reflectorch/models/networks/__init__.py +2 -0
  32. reflectorch/models/networks/mlp_networks.py +331 -153
  33. reflectorch/models/networks/residual_net.py +31 -5
  34. reflectorch/runs/train.py +0 -1
  35. reflectorch/runs/utils.py +48 -11
  36. reflectorch/utils.py +30 -0
  37. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/METADATA +20 -17
  38. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/RECORD +41 -36
  39. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
  40. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info/licenses}/LICENSE.txt +0 -0
  41. {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
reflectorch/data_generation/smearing.py

@@ -9,15 +9,15 @@ class Smearing(object):
     The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.
 
     Args:
-        sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (1e-4, 5e-3).
+        sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (0.01, 0.1).
         constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
             otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to True.
         gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
         share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
     """
     def __init__(self,
-                 sigma_range: tuple = (1e-4, 5e-3),
-                 constant_dq: bool = True,
+                 sigma_range: tuple = (0.01, 0.1),
+                 constant_dq: bool = False,
                  gauss_num: int = 31,
                  share_smeared: float = 0.2,
                  ):
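Not part of the diff: a minimal construction sketch using the new 1.4.0 defaults shown above. The import path is assumed from the file path (reflectorch/data_generation/smearing.py); the package may also re-export the class elsewhere.

# Illustrative only: instantiate Smearing with the defaults introduced in this release
from reflectorch.data_generation.smearing import Smearing  # assumed import path

smearing = Smearing(
    sigma_range=(0.01, 0.1),   # new default range for sampled resolutions
    constant_dq=False,         # new default: linear smearing (resolution given by a constant dq/q)
    gauss_num=31,              # number of interpolating gaussian profiles
    share_smeared=0.2,         # 20% of curves in a batch get resolution smearing
)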
@@ -38,31 +38,62 @@ class Smearing(object):
         indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
         indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
         return dq, indices
+
+    def scale_resolutions(self, resolutions: Tensor) -> Tensor:
+        """Scales the q-resolution values to [-1,1] range using the internal sigma range"""
+        sigma_min = 0.0 if self.share_smeared != 1.0 else self.sigma_min
+        return 2 * (resolutions - sigma_min) / (self.sigma_max - sigma_min) - 1
+
+    def get_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs: dict = None):
+        refl_kwargs = refl_kwargs or {}
 
-    def get_curves(self, q_values: Tensor, params: BasicParams):
         dq, indices = self.generate_resolutions(params.batch_size, device=params.device, dtype=params.dtype)
+        q_resolutions = torch.zeros(q_values.shape[0], 1, dtype=q_values.dtype, device=q_values.device)
 
         if dq is None:
-            return params.reflectivity(q_values, log=False)
-
-        curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)
+            return params.reflectivity(q_values, **refl_kwargs), q_resolutions
+
+        refl_kwargs_not_smeared = {}
+        refl_kwargs_smeared = {}
+        for key, value in refl_kwargs.items():
+            if isinstance(value, torch.Tensor) and value.shape[0] == params.batch_size:
+                refl_kwargs_not_smeared[key] = value[~indices]
+                refl_kwargs_smeared[key] = value[indices]
+            else:
+                refl_kwargs_not_smeared[key] = value
+                refl_kwargs_smeared[key] = value
 
+        # Compute unsmeared reflectivity
         if (~indices).sum().item():
             if q_values.dim() == 2 and q_values.shape[0] > 1:
                 q = q_values[~indices]
             else:
                 q = q_values
 
-            curves[~indices] = params[~indices].reflectivity(q, log=False)
+            reflectivity_not_smeared = params[~indices].reflectivity(q, **refl_kwargs_not_smeared)
+        else:
+            reflectivity_not_smeared = None
 
+        # Compute smeared reflectivity
         if indices.sum().item():
             if q_values.dim() == 2 and q_values.shape[0] > 1:
                 q = q_values[indices]
             else:
                 q = q_values
 
-            curves[indices] = params[indices].reflectivity(
-                q, dq=dq, constant_dq=self.constant_dq, log=False, gauss_num=self.gauss_num
+            reflectivity_smeared = params[indices].reflectivity(
+                q, dq=dq, constant_dq=self.constant_dq, gauss_num=self.gauss_num, **refl_kwargs_smeared
             )
+        else:
+            reflectivity_smeared = None
+
+        curves = torch.empty(params.batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)
+
+        if (~indices).sum().item():
+            curves[~indices] = reflectivity_not_smeared
+
+        curves[indices] = reflectivity_smeared
+
+        q_resolutions[indices] = dq
 
-        return curves
+        return curves, q_resolutions
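The class docstring describes the smearing as a Gaussian-weighted average over neighbouring q points, which the class approximates with gauss_num interpolating Gaussian profiles. The following standalone sketch illustrates that weighting directly; it is an illustration of the concept, not reflectorch's implementation.

import torch

def gaussian_smear(q, refl, dq, constant_dq=False):
    """Illustrative resolution smearing: each output intensity is a Gaussian-weighted
    average of the intensities at neighbouring q points."""
    # per-point width: constant dq, or constant relative resolution dq/q ("linear" smearing)
    sigma = torch.full_like(q, dq) if constant_dq else dq * q
    # pairwise Gaussian weights between output points (rows) and neighbouring points (columns)
    weights = torch.exp(-0.5 * ((q[:, None] - q[None, :]) / sigma[:, None]) ** 2)
    weights = weights / weights.sum(dim=-1, keepdim=True)   # normalize each row
    return weights @ refl

# Example: smear a synthetic curve with a 1% relative resolution
q = torch.linspace(0.02, 0.15, 128)
refl = torch.exp(-40.0 * q)                 # stand-in for a reflectivity curve
smeared = gaussian_smear(q, refl, dq=0.01)  # dq/q = 0.01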
reflectorch/data_generation/utils.py

@@ -57,7 +57,7 @@ def get_reversed_params(thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):
     return reversed_params
 
 
-def get_density_profiles(
+def get_density_profiles_sld(
     thicknesses: Tensor,
     roughnesses: Tensor,
     slds: Tensor,
@@ -120,29 +120,104 @@ def get_erf(z, z0, sigma, amp):
 def get_gauss(z, z0, sigma, amp):
     return amp / (sigma * sqrt(2 * pi)) * torch.exp(- (z - z0) ** 2 / 2 / sigma ** 2)
 
+def get_density_profiles(
+    thicknesses: torch.Tensor,
+    roughnesses: torch.Tensor,
+    slds: torch.Tensor,
+    ambient_sld: torch.Tensor = None,
+    z_axis: torch.Tensor = None,
+    num: int = 1000,
+    padding_left: float = 0.2,
+    padding_right: float = 1.1,
+):
+    """
+    Args:
+        thicknesses (Tensor): finite layer thicknesses.
+        roughnesses (Tensor): interface roughnesses for all transitions (ambient→layer1 ... layerN→substrate).
+        slds (Tensor): SLDs for the finite layers + substrate.
+        ambient_sld (Tensor, optional): SLD for the top ambient. Defaults to 0.0 if None.
+        z_axis (Tensor, optional): a custom depth axis. If None, a linear axis is generated.
+        num (int): number of points in the generated z-axis (if z_axis is None).
+        padding_left (float): factor to extend the negative (above the surface) portion of the z-axis.
+        padding_right (float): factor to extend the positive (into the sample) portion of the z-axis.
+
+    Returns:
+        (z_axis, profile, d_profile)
+        z_axis: 1D Tensor of shape (num,) with the depth coordinates.
+        profile: 2D Tensor of shape (batch_size, num) giving the SLD at each depth.
+        d_profile: 2D Tensor of shape (batch_size, num) giving d(SLD)/dz at each depth.
+    """
+
+    bs, n = thicknesses.shape
+    assert roughnesses.shape == (bs, n + 1), (
+        f"Roughnesses must be (batch_size, num_layers+1). Found {roughnesses.shape} instead."
+    )
+    assert slds.shape == (bs, n + 1), (
+        f"SLDs must be (batch_size, num_layers+1). Found {slds.shape} instead."
+    )
+    assert torch.all(thicknesses >= 0), "Negative thickness encountered."
+    assert torch.all(roughnesses >= 0), "Negative roughness encountered."
+
+    if ambient_sld is None:
+        ambient_sld = torch.zeros((bs, 1), device=thicknesses.device)
+    else:
+        if ambient_sld.ndim == 1:
+            ambient_sld = ambient_sld.unsqueeze(-1)
+        ambient_sld = ambient_sld.expand(bs, 1)
+
+    slds_all = torch.cat([ambient_sld, slds], dim=-1)  # new dimension: n+2
+    d_rhos = torch.diff(slds_all, dim=-1)              # (bs, n+1)
+
+    interfaces = torch.cat([
+        torch.zeros((bs, 1), device=thicknesses.device),  # z=0 for ambient→layer1
+        thicknesses
+    ], dim=-1).cumsum(dim=-1)                             # now shape => (bs, n+1)
+
+    total_thickness = interfaces[..., -1].max()
+    if z_axis is None:
+        z_axis = torch.linspace(
+            -padding_left * total_thickness,
+            padding_right * total_thickness,
+            num,
+            device=thicknesses.device
+        )  # shape => (num,)
+    if z_axis.ndim == 1:
+        z_axis = z_axis.unsqueeze(0)  # shape => (1, num)
+
+    z_b = z_axis.repeat(bs, 1).unsqueeze(1)           # (bs, 1, num)
+    interfaces_b = interfaces.unsqueeze(-1)           # (bs, n+1, 1)
+    sigmas_b = (roughnesses * sqrt(2)).unsqueeze(-1)  # (bs, n+1, 1)
+    d_rhos_b = d_rhos.unsqueeze(-1)                   # (bs, n+1, 1)
+
+    profile = get_erf(z_b, interfaces_b, sigmas_b, d_rhos_b).sum(dim=1)  # (bs, num)
+    if ambient_sld is not None:
+        profile = profile + ambient_sld
+
+    d_profile = get_gauss(z_b, interfaces_b, sigmas_b, d_rhos_b).sum(dim=1)  # (bs, num)
+
+    return z_axis.squeeze(0), profile, d_profile
 
 def get_param_labels(
     num_layers: int, *,
     thickness_name: str = 'Thickness',
     roughness_name: str = 'Roughness',
     sld_name: str = 'SLD',
-    substrate_name: str = 'sub',
-) -> List[str]:
-    thickness_labels = [f'{thickness_name} L{num_layers - i}' for i in range(num_layers)]
-    roughness_labels = [f'{roughness_name} L{num_layers - i}' for i in range(num_layers)] + [f'{roughness_name} {substrate_name}']
-    sld_labels = [f'{sld_name} L{num_layers - i}' for i in range(num_layers)] + [f'{sld_name} {substrate_name}']
-    return thickness_labels + roughness_labels + sld_labels
-
-def get_param_labels_absorption_model(
-    num_layers: int, *,
-    thickness_name: str = 'Thickness',
-    roughness_name: str = 'Roughness',
-    real_sld_name: str = 'SLD real',
     imag_sld_name: str = 'SLD imag',
     substrate_name: str = 'sub',
+    parameterization_type: str = 'standard',
+    number_top_to_bottom: bool = False,
 ) -> List[str]:
-    thickness_labels = [f'{thickness_name} L{num_layers - i}' for i in range(num_layers)]
-    roughness_labels = [f'{roughness_name} L{num_layers - i}' for i in range(num_layers)] + [f'{roughness_name} {substrate_name}']
-    real_sld_labels = [f'{real_sld_name} L{num_layers - i}' for i in range(num_layers)] + [f'{real_sld_name} {substrate_name}']
-    imag_sld_labels = [f'{imag_sld_name} L{num_layers - i}' for i in range(num_layers)] + [f'{imag_sld_name} {substrate_name}']
-    return thickness_labels + roughness_labels + real_sld_labels + imag_sld_labels
+    def pos(i):
+        return i + 1 if number_top_to_bottom else num_layers - i
+
+    thickness_labels = [f'{thickness_name} L{pos(i)}' for i in range(num_layers)]
+    roughness_labels = [f'{roughness_name} L{pos(i)}' for i in range(num_layers)] + [f'{roughness_name} {substrate_name}']
+    sld_labels = [f'{sld_name} L{pos(i)}' for i in range(num_layers)] + [f'{sld_name} {substrate_name}']
+
+    all_labels = thickness_labels + roughness_labels + sld_labels
+
+    if parameterization_type == 'absorption':
+        imag_sld_labels = [f'{imag_sld_name} L{pos(i)}' for i in range(num_layers)] + [f'{imag_sld_name} {substrate_name}']
+        all_labels = all_labels + imag_sld_labels
+
+    return all_labels
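Not part of the diff: a short usage sketch consistent with the new signatures above. The import path is assumed from the file path (reflectorch/data_generation/utils.py).

import torch
from reflectorch.data_generation.utils import get_density_profiles, get_param_labels  # assumed import path

# One two-layer sample: thicknesses have shape (batch, n); roughnesses and SLDs (batch, n + 1)
thicknesses = torch.tensor([[60.0, 120.0]])
roughnesses = torch.tensor([[3.0, 5.0, 2.0]])
slds = torch.tensor([[4.6, 2.1, 20.1]])          # two layers + substrate

z, profile, d_profile = get_density_profiles(thicknesses, roughnesses, slds, num=500)
# z: (500,), profile and d_profile: (1, 500)

labels = get_param_labels(2)
# ['Thickness L2', 'Thickness L1',
#  'Roughness L2', 'Roughness L1', 'Roughness sub',
#  'SLD L2', 'SLD L1', 'SLD sub']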
reflectorch/extensions/refnx/refnx_conversion.py (new file)

@@ -0,0 +1,77 @@
+import numpy as np
+from functools import reduce
+from operator import or_
+
+from reflectorch.inference.inference_model import EasyInferenceModel
+from reflectorch import BasicParams
+
+import refnx
+from refnx.dataset import ReflectDataset, Data1D
+from refnx.analysis import Transform, CurveFitter, Objective, Model, Parameter
+from refnx.reflect import SLD, Slab, ReflectModel
+
+def covert_reflectorch_prediction_to_refnx_structure(inference_model: EasyInferenceModel, pred_params_object: BasicParams, prior_bounds: np.array):
+    assert inference_model.trainer.loader.prior_sampler.param_model.__class__.__name__ == 'StandardModel'
+
+    n_layers = inference_model.trainer.loader.prior_sampler.max_num_layers
+    init_thicknesses = pred_params_object.thicknesses.squeeze().tolist()
+    init_roughnesses = pred_params_object.roughnesses.squeeze().tolist()
+    init_slds = pred_params_object.slds.squeeze().tolist()
+
+    sld_objects = []
+
+    for sld in init_slds:
+        sld_objects.append(SLD(value=sld))
+
+    layer_objects = [SLD(0)()]
+    for i in range(n_layers):
+        layer_objects.append(sld_objects[i](init_thicknesses[i], init_roughnesses[i]))
+
+    layer_objects.append(sld_objects[-1](0, init_roughnesses[-1]))
+
+    thickness_bounds = prior_bounds[:n_layers]
+    roughness_bounds = prior_bounds[n_layers:2*n_layers+1]
+    sld_bounds = prior_bounds[2*n_layers+1:]
+
+    for i, layer in enumerate(layer_objects):
+        if i == 0:
+            print("Ambient (air)")
+            print(80 * '-')
+        elif i < n_layers + 1:
+            layer.thick.setp(bounds=thickness_bounds[i-1], vary=True)
+            layer.rough.setp(bounds=roughness_bounds[i-1], vary=True)
+            layer.sld.real.setp(bounds=sld_bounds[i-1], vary=True)
+
+            print(f'Layer {i}')
+            print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
+            print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
+            print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')
+            print(80 * '-')
+        else:  # substrate
+            layer.rough.setp(bounds=roughness_bounds[i-1], vary=True)
+            layer.sld.real.setp(bounds=sld_bounds[i-1], vary=True)
+
+            print(f'Substrate')
+            print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
+            print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
+            print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')
+
+    refnx_structure = reduce(or_, layer_objects)
+
+    return refnx_structure
+
+
+### Example usage:
+# refnx_structure = covert_reflectorch_prediction_to_refnx_structure(inference_model, pred_params_object, prior_bounds)
+
+# refnx_reflect_model = ReflectModel(refnx_structure, bkg=1e-10, dq=0.0)
+# refnx_reflect_model.scale.setp(bounds=(0.8, 1.2), vary=True)
+# refnx_reflect_model.q_offset.setp(bounds=(-0.01, 0.01), vary=True)
+# refnx_reflect_model.bkg.setp(bounds=(1e-10, 1e-8), vary=True)
+
+# data = Data1D(data=(q_model, exp_curve_interp))
+
+# refnx_objective = Objective(refnx_reflect_model, data, transform=Transform("logY"))
+# fitter = CurveFitter(refnx_objective)
+# fitter.fit('least_squares')
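Not part of the package: after running the commented example above, the fit can be inspected with standard refnx calls. This sketch assumes refnx_objective, refnx_reflect_model, and q_model exist exactly as in that example.

import matplotlib.pyplot as plt

print(refnx_objective)                   # parameter report for the fitted objective
best_fit = refnx_reflect_model(q_model)  # reflectivity of the fitted model at q_model

plt.semilogy(q_model, best_fit)
plt.xlabel('q')
plt.ylabel('R(q)')
plt.show()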