reflectorch 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/data_generation/__init__.py +4 -0
- reflectorch/data_generation/dataset.py +27 -7
- reflectorch/data_generation/noise.py +115 -9
- reflectorch/data_generation/priors/parametric_models.py +91 -16
- reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
- reflectorch/data_generation/priors/sampler_strategies.py +67 -3
- reflectorch/data_generation/q_generator.py +97 -43
- reflectorch/data_generation/reflectivity/__init__.py +53 -11
- reflectorch/data_generation/reflectivity/kinematical.py +4 -5
- reflectorch/data_generation/reflectivity/smearing.py +25 -10
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/smearing.py +42 -11
- reflectorch/data_generation/utils.py +93 -18
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/inference_model.py +795 -159
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/plotting.py +517 -0
- reflectorch/inference/preprocess_exp/interpolation.py +5 -2
- reflectorch/inference/scipy_fitter.py +98 -7
- reflectorch/ml/__init__.py +2 -0
- reflectorch/ml/basic_trainer.py +18 -6
- reflectorch/ml/callbacks.py +5 -4
- reflectorch/ml/loggers.py +25 -0
- reflectorch/ml/schedulers.py +116 -0
- reflectorch/ml/trainers.py +131 -23
- reflectorch/models/__init__.py +2 -1
- reflectorch/models/encoders/__init__.py +2 -2
- reflectorch/models/encoders/conv_encoder.py +54 -40
- reflectorch/models/encoders/fno.py +23 -16
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +2 -0
- reflectorch/models/networks/mlp_networks.py +331 -153
- reflectorch/models/networks/residual_net.py +31 -5
- reflectorch/runs/train.py +0 -1
- reflectorch/runs/utils.py +48 -11
- reflectorch/utils.py +30 -0
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/METADATA +20 -17
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/RECORD +41 -36
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info/licenses}/LICENSE.txt +0 -0
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
|
@@ -9,15 +9,15 @@ class Smearing(object):
|
|
|
9
9
|
The intensity at a q point will be the average of the intensities of neighbouring q points, weighted by a gaussian profile.
|
|
10
10
|
|
|
11
11
|
Args:
|
|
12
|
-
sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (
|
|
12
|
+
sigma_range (tuple, optional): the range for sampling the resolutions. Defaults to (0.01, 0.1).
|
|
13
13
|
constant_dq (bool, optional): if ``True`` the smearing is constant (the resolution is given by the constant dq at each point in the curve)
|
|
14
14
|
otherwise the smearing is linear (the resolution is given by the constant dq/q at each point in the curve). Defaults to False.
|
|
15
15
|
gauss_num (int, optional): the number of interpolating gaussian profiles. Defaults to 31.
|
|
16
16
|
share_smeared (float, optional): the share of curves in the batch for which the resolution smearing is applied. Defaults to 0.2.
|
|
17
17
|
"""
|
|
18
18
|
def __init__(self,
|
|
19
|
-
sigma_range: tuple = (
|
|
20
|
-
constant_dq: bool =
|
|
19
|
+
sigma_range: tuple = (0.01, 0.1),
|
|
20
|
+
constant_dq: bool = False,
|
|
21
21
|
gauss_num: int = 31,
|
|
22
22
|
share_smeared: float = 0.2,
|
|
23
23
|
):
|
|
@@ -38,31 +38,62 @@ class Smearing(object):
|
|
|
38
38
|
indices = torch.zeros(batch_size, device=device, dtype=torch.bool)
|
|
39
39
|
indices[torch.randperm(batch_size, device=device)[:num_smeared]] = True
|
|
40
40
|
return dq, indices
|
|
41
|
+
|
|
42
|
+
def scale_resolutions(self, resolutions: Tensor) -> Tensor:
    """Linearly rescale q-resolution values using the sampler's sigma range.

    The upper end of the range is always ``self.sigma_max``. The lower end is
    ``self.sigma_min`` only when ``self.share_smeared`` equals 1.0; in every
    other case 0.0 is used as the lower end.

    Args:
        resolutions (Tensor): the q-resolution values to rescale.

    Returns:
        Tensor: ``2 * (resolutions - lower) / (sigma_max - lower) - 1``, i.e.
        the lower end maps to -1 and ``sigma_max`` maps to 1.
    """
    if self.share_smeared == 1.0:
        lower = self.sigma_min
    else:
        lower = 0.0
    span = self.sigma_max - lower
    return 2 * (resolutions - lower) / span - 1
|
|
46
|
+
|
|
47
|
+
def get_curves(self, q_values: Tensor, params: BasicParams, refl_kwargs: dict = None):
    """Simulate reflectivity curves, smearing a sampled subset of the batch.

    ``self.generate_resolutions`` decides which curves are smeared (boolean
    mask ``indices``) and with which resolutions (``dq``). Smeared and
    unsmeared curves are simulated separately and written back into a single
    output tensor in the original batch order.

    Args:
        q_values (Tensor): momentum transfer values. Either one row per curve
            (2-D with more than one row) or a single q axis shared by the
            whole batch (any other shape).
        params (BasicParams): the parameters of the batch; must provide
            ``batch_size``, ``device``, ``dtype``, mask indexing and
            ``reflectivity``.
        refl_kwargs (dict, optional): extra keyword arguments forwarded to
            ``params.reflectivity``. Tensor values whose first dimension
            equals the batch size are split along that dimension with the
            same mask as the curves; everything else is passed unchanged.

    Returns:
        tuple: ``(curves, q_resolutions)``. When smearing is sampled,
        ``curves`` has shape ``(batch_size, q_values.shape[-1])``.
        ``q_resolutions`` has shape ``(batch_size, 1)`` and holds ``dq`` for
        the smeared curves and 0 for the others. When no smearing is sampled
        (``dq is None``) the result of ``params.reflectivity`` is returned
        as-is together with all-zero resolutions.
    """
    refl_kwargs = refl_kwargs or {}
    batch_size = params.batch_size

    dq, indices = self.generate_resolutions(batch_size, device=params.device, dtype=params.dtype)

    # One resolution per curve. Sizing this from q_values.shape[0] would break
    # the mask assignment below whenever a single q axis is shared by the batch.
    q_resolutions = torch.zeros(batch_size, 1, dtype=q_values.dtype, device=q_values.device)

    if dq is None:
        return params.reflectivity(q_values, **refl_kwargs), q_resolutions

    per_curve_q = q_values.dim() == 2 and q_values.shape[0] > 1

    def _is_batched(value):
        # dim() > 0 guards against 0-dim tensors, which have no shape[0].
        return isinstance(value, torch.Tensor) and value.dim() > 0 and value.shape[0] == batch_size

    def _select(mask):
        """Return the q values and reflectivity kwargs restricted to ``mask``."""
        q = q_values[mask] if per_curve_q else q_values
        kwargs = {
            key: (value[mask] if _is_batched(value) else value)
            for key, value in refl_kwargs.items()
        }
        return q, kwargs

    not_smeared = ~indices
    curves = torch.empty(batch_size, q_values.shape[-1], device=params.device, dtype=params.dtype)

    # Unsmeared part of the batch
    if not_smeared.any().item():
        q, kwargs = _select(not_smeared)
        curves[not_smeared] = params[not_smeared].reflectivity(q, **kwargs)

    # Smeared part of the batch; skipped entirely when the mask is empty so
    # that nothing is assigned from a missing result.
    if indices.any().item():
        q, kwargs = _select(indices)
        curves[indices] = params[indices].reflectivity(
            q, dq=dq, constant_dq=self.constant_dq, gauss_num=self.gauss_num, **kwargs
        )
        q_resolutions[indices] = dq

    return curves, q_resolutions
|
|
@@ -57,7 +57,7 @@ def get_reversed_params(thicknesses: Tensor, roughnesses: Tensor, slds: Tensor):
|
|
|
57
57
|
return reversed_params
|
|
58
58
|
|
|
59
59
|
|
|
60
|
-
def
|
|
60
|
+
def get_density_profiles_sld(
|
|
61
61
|
thicknesses: Tensor,
|
|
62
62
|
roughnesses: Tensor,
|
|
63
63
|
slds: Tensor,
|
|
@@ -120,29 +120,104 @@ def get_erf(z, z0, sigma, amp):
|
|
|
120
120
|
def get_gauss(z, z0, sigma, amp):
    """Evaluate a Gaussian of area ``amp`` centred at ``z0`` with width ``sigma``.

    Computes ``amp / (sigma * sqrt(2 * pi)) * exp(-(z - z0)**2 / (2 * sigma**2))``
    element-wise; the arguments are combined with ordinary broadcasting.
    """
    prefactor = amp / (sigma * sqrt(2 * pi))
    exponent = -(z - z0) ** 2 / 2 / sigma ** 2
    return prefactor * torch.exp(exponent)
|
|
122
122
|
|
|
123
|
+
def get_density_profiles(
        thicknesses: torch.Tensor,
        roughnesses: torch.Tensor,
        slds: torch.Tensor,
        ambient_sld: torch.Tensor = None,
        z_axis: torch.Tensor = None,
        num: int = 1000,
        padding_left: float = 0.2,
        padding_right: float = 1.1,
):
    """Compute SLD depth profiles (and their derivatives) for a batch of layer stacks.

    Each interface contributes one ``get_erf`` step to the profile and one
    ``get_gauss`` peak to its derivative, both centred at the interface depth
    and using ``roughness * sqrt(2)`` as the width argument.

    Args:
        thicknesses (Tensor): finite layer thicknesses, shape (batch_size, num_layers).
        roughnesses (Tensor): interface roughnesses for all transitions
            (ambient→layer1 ... layerN→substrate), shape (batch_size, num_layers+1).
        slds (Tensor): SLDs of the finite layers followed by the substrate,
            shape (batch_size, num_layers+1).
        ambient_sld (Tensor, optional): SLD of the top ambient. A 1-D tensor is
            treated as one value per batch element. Zeros are used if None.
        z_axis (Tensor, optional): a custom depth axis. If None, a linear axis
            is generated from the thickest stack in the batch.
        num (int): number of points of the generated z-axis (only used if z_axis is None).
        padding_left (float): the generated axis starts at ``-padding_left * total_thickness``.
        padding_right (float): the generated axis ends at ``padding_right * total_thickness``.

    Returns:
        (z_axis, profile, d_profile)
            z_axis: the depth coordinates with a leading singleton dimension squeezed away.
            profile: 2D Tensor of shape (batch_size, num) giving the SLD at each depth.
            d_profile: 2D Tensor of shape (batch_size, num) giving d(SLD)/dz at each depth.

    Raises:
        AssertionError: if the input shapes are inconsistent or a thickness or
            roughness is negative.
    """
    batch_size, num_layers = thicknesses.shape
    per_interface_shape = (batch_size, num_layers + 1)
    assert roughnesses.shape == per_interface_shape, (
        f"Roughnesses must be (batch_size, num_layers+1). Found {roughnesses.shape} instead."
    )
    assert slds.shape == per_interface_shape, (
        f"SLDs must be (batch_size, num_layers+1). Found {slds.shape} instead."
    )
    assert torch.all(thicknesses >= 0), "Negative thickness encountered."
    assert torch.all(roughnesses >= 0), "Negative roughness encountered."

    device = thicknesses.device

    if ambient_sld is None:
        ambient_sld = torch.zeros((batch_size, 1), device=device)
    else:
        if ambient_sld.ndim == 1:
            ambient_sld = ambient_sld.unsqueeze(-1)
        ambient_sld = ambient_sld.expand(batch_size, 1)

    # SLD jump at every interface, ambient included: (batch_size, num_layers+1)
    sld_steps = torch.diff(torch.cat([ambient_sld, slds], dim=-1), dim=-1)

    # Interface depths; the ambient→layer1 interface sits at z = 0.
    surface = torch.zeros((batch_size, 1), device=device)
    interfaces = torch.cumsum(torch.cat([surface, thicknesses], dim=-1), dim=-1)

    if z_axis is None:
        total_thickness = interfaces[..., -1].max()
        z_axis = torch.linspace(
            -padding_left * total_thickness,
            padding_right * total_thickness,
            num,
            device=device,
        )
    if z_axis.ndim == 1:
        z_axis = z_axis.unsqueeze(0)  # (1, num)

    # Shapes arranged for broadcasting over (batch, interface, depth).
    z_grid = z_axis.repeat(batch_size, 1).unsqueeze(1)   # (batch_size, 1, num)
    centers = interfaces.unsqueeze(-1)                   # (batch_size, num_layers+1, 1)
    widths = (roughnesses * sqrt(2)).unsqueeze(-1)       # (batch_size, num_layers+1, 1)
    steps = sld_steps.unsqueeze(-1)                      # (batch_size, num_layers+1, 1)

    # Sum of the per-interface steps, offset by the ambient SLD.
    profile = get_erf(z_grid, centers, widths, steps).sum(dim=1) + ambient_sld
    d_profile = get_gauss(z_grid, centers, widths, steps).sum(dim=1)

    return z_axis.squeeze(0), profile, d_profile
|
|
123
199
|
|
|
124
200
|
def get_param_labels(
        num_layers: int, *,
        thickness_name: str = 'Thickness',
        roughness_name: str = 'Roughness',
        sld_name: str = 'SLD',
        imag_sld_name: str = 'SLD imag',
        substrate_name: str = 'sub',
        parameterization_type: str = 'standard',
        number_top_to_bottom: bool = False,
) -> List[str]:
    """Build the list of human-readable parameter labels for a layer stack.

    Labels are ordered as thicknesses, roughnesses, SLDs and, only when
    ``parameterization_type == 'absorption'``, imaginary SLDs. Every group
    except the thicknesses ends with an extra entry for the substrate.

    Args:
        num_layers: number of finite layers.
        thickness_name: prefix of the thickness labels.
        roughness_name: prefix of the roughness labels.
        sld_name: prefix of the SLD labels.
        imag_sld_name: prefix of the imaginary SLD labels.
        substrate_name: suffix used for the substrate entries.
        parameterization_type: ``'absorption'`` appends the imaginary SLD
            labels; any other value yields only the first three groups.
        number_top_to_bottom: if True the layers are numbered ``L1 ... LN``
            in list order, otherwise ``LN ... L1``.

    Returns:
        The concatenated list of labels.
    """
    if number_top_to_bottom:
        layer_numbers = list(range(1, num_layers + 1))
    else:
        layer_numbers = list(range(num_layers, 0, -1))

    def group(prefix, with_substrate=True):
        names = [f'{prefix} L{number}' for number in layer_numbers]
        if with_substrate:
            names.append(f'{prefix} {substrate_name}')
        return names

    labels = group(thickness_name, with_substrate=False)
    labels += group(roughness_name)
    labels += group(sld_name)

    if parameterization_type == 'absorption':
        labels += group(imag_sld_name)

    return labels
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from functools import reduce
|
|
3
|
+
from operator import or_
|
|
4
|
+
|
|
5
|
+
from reflectorch.inference.inference_model import EasyInferenceModel
|
|
6
|
+
from reflectorch import BasicParams
|
|
7
|
+
|
|
8
|
+
import refnx
|
|
9
|
+
from refnx.dataset import ReflectDataset, Data1D
|
|
10
|
+
from refnx.analysis import Transform, CurveFitter, Objective, Model, Parameter
|
|
11
|
+
from refnx.reflect import SLD, Slab, ReflectModel
|
|
12
|
+
|
|
13
|
+
def covert_reflectorch_prediction_to_refnx_structure(inference_model: EasyInferenceModel, pred_params_object: BasicParams, prior_bounds: np.ndarray):
    """Build a refnx structure initialised from a reflectorch prediction.

    The structure is ``ambient | layer 1 | ... | layer N | substrate``. The
    ambient slab is created with an SLD of 0; every other slab takes its SLD,
    thickness and roughness from the predicted parameters. The thickness,
    roughness and real SLD of each layer (roughness and real SLD only for the
    substrate) are set to vary within the supplied prior bounds. A summary of
    every slab is printed while the structure is assembled.

    NOTE: the function name keeps its historical spelling ("covert") so that
    existing callers continue to work.

    Args:
        inference_model: the inference model; its prior sampler must use a
            parameter model whose class is named ``StandardModel``.
        pred_params_object: the predicted parameters (``thicknesses``,
            ``roughnesses`` and ``slds`` attributes are read).
            Presumably a single prediction (batch size 1) - TODO confirm.
        prior_bounds: per-parameter bounds ordered as thicknesses (N entries),
            roughnesses (N+1 entries) and SLDs (remaining entries); each entry
            is handed to refnx as the ``bounds`` of one parameter.

    Returns:
        The refnx structure obtained by joining all slabs with ``|``.

    Raises:
        AssertionError: if the parameter model is not a ``StandardModel``.
    """
    prior_sampler = inference_model.trainer.loader.prior_sampler
    assert prior_sampler.param_model.__class__.__name__ == 'StandardModel'

    n_layers = prior_sampler.max_num_layers

    def _as_list(tensor):
        # squeeze() turns a single-element tensor (e.g. the thickness of a
        # one-layer model) into a scalar whose tolist() is a bare float;
        # atleast_1d keeps the result indexable in that case.
        return np.atleast_1d(tensor.squeeze().tolist()).tolist()

    init_thicknesses = _as_list(pred_params_object.thicknesses)
    init_roughnesses = _as_list(pred_params_object.roughnesses)
    init_slds = _as_list(pred_params_object.slds)

    sld_objects = [SLD(value=sld) for sld in init_slds]

    # Ambient first, then the finite layers, then the substrate (thickness 0).
    layer_objects = [SLD(0)()]
    for i in range(n_layers):
        layer_objects.append(sld_objects[i](init_thicknesses[i], init_roughnesses[i]))
    layer_objects.append(sld_objects[-1](0, init_roughnesses[-1]))

    thickness_bounds = prior_bounds[:n_layers]
    roughness_bounds = prior_bounds[n_layers:2 * n_layers + 1]
    sld_bounds = prior_bounds[2 * n_layers + 1:]

    for i, layer in enumerate(layer_objects):
        if i == 0:
            print("Ambient (air)")
            print(80 * '-')
        elif i < n_layers + 1:
            # Slab i corresponds to parameter index i - 1 (the ambient has none).
            layer.thick.setp(bounds=thickness_bounds[i - 1], vary=True)
            layer.rough.setp(bounds=roughness_bounds[i - 1], vary=True)
            layer.sld.real.setp(bounds=sld_bounds[i - 1], vary=True)

            print(f'Layer {i}')
            print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
            print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
            print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')
            print(80 * '-')
        else:  # substrate: its thickness is left untouched
            layer.rough.setp(bounds=roughness_bounds[i - 1], vary=True)
            layer.sld.real.setp(bounds=sld_bounds[i - 1], vary=True)

            print('Substrate')
            print(f'Thickness: value {layer.thick.value}, vary {layer.thick.vary}, bounds {layer.thick.bounds}')
            print(f'Roughness: value {layer.rough.value}, vary {layer.rough.vary}, bounds {layer.rough.bounds}')
            print(f'SLD: value {layer.sld.real.value}, vary {layer.sld.real.vary}, bounds {layer.sld.real.bounds}')

    refnx_structure = reduce(or_, layer_objects)

    return refnx_structure
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
###Example usage:
|
|
65
|
+
# refnx_structure = covert_reflectorch_prediction_to_refnx_structure(inference_model, pred_params_object, prior_bounds)
|
|
66
|
+
|
|
67
|
+
# refnx_reflect_model = ReflectModel(refnx_structure, bkg=1e-10, dq=0.0)
|
|
68
|
+
# refnx_reflect_model.scale.setp(bounds=(0.8, 1.2), vary=True)
|
|
69
|
+
# refnx_reflect_model.q_offset.setp(bounds=(-0.01, 0.01), vary=True)
|
|
70
|
+
# refnx_reflect_model.bkg.setp(bounds=(1e-10, 1e-8), vary=True)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# data = Data1D(data=(q_model, exp_curve_interp))
|
|
74
|
+
|
|
75
|
+
# refnx_objective = Objective(refnx_reflect_model, data, transform=Transform("logY"))
|
|
76
|
+
# fitter = CurveFitter(refnx_objective)
|
|
77
|
+
# fitter.fit('least_squares')
|