reflectorch 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of reflectorch might be problematic. Click here for more details.
- reflectorch/data_generation/__init__.py +4 -0
- reflectorch/data_generation/dataset.py +27 -7
- reflectorch/data_generation/noise.py +115 -9
- reflectorch/data_generation/priors/parametric_models.py +91 -16
- reflectorch/data_generation/priors/parametric_subpriors.py +28 -7
- reflectorch/data_generation/priors/sampler_strategies.py +67 -3
- reflectorch/data_generation/q_generator.py +97 -43
- reflectorch/data_generation/reflectivity/__init__.py +53 -11
- reflectorch/data_generation/reflectivity/kinematical.py +4 -5
- reflectorch/data_generation/reflectivity/smearing.py +25 -10
- reflectorch/data_generation/reflectivity/smearing_pointwise.py +110 -0
- reflectorch/data_generation/smearing.py +42 -11
- reflectorch/data_generation/utils.py +93 -18
- reflectorch/extensions/refnx/refnx_conversion.py +77 -0
- reflectorch/inference/inference_model.py +795 -159
- reflectorch/inference/loading_data.py +37 -0
- reflectorch/inference/plotting.py +517 -0
- reflectorch/inference/preprocess_exp/interpolation.py +5 -2
- reflectorch/inference/scipy_fitter.py +98 -7
- reflectorch/ml/__init__.py +2 -0
- reflectorch/ml/basic_trainer.py +18 -6
- reflectorch/ml/callbacks.py +5 -4
- reflectorch/ml/loggers.py +25 -0
- reflectorch/ml/schedulers.py +116 -0
- reflectorch/ml/trainers.py +131 -23
- reflectorch/models/__init__.py +2 -1
- reflectorch/models/encoders/__init__.py +2 -2
- reflectorch/models/encoders/conv_encoder.py +54 -40
- reflectorch/models/encoders/fno.py +23 -16
- reflectorch/models/encoders/integral_kernel_embedding.py +390 -0
- reflectorch/models/networks/__init__.py +2 -0
- reflectorch/models/networks/mlp_networks.py +331 -153
- reflectorch/models/networks/residual_net.py +31 -5
- reflectorch/runs/train.py +0 -1
- reflectorch/runs/utils.py +48 -11
- reflectorch/utils.py +30 -0
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/METADATA +20 -17
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/RECORD +41 -36
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/WHEEL +1 -1
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info/licenses}/LICENSE.txt +0 -0
- {reflectorch-1.2.1.dist-info → reflectorch-1.4.0.dist-info}/top_level.txt +0 -0
|
@@ -73,11 +73,13 @@ class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
|
|
|
73
73
|
roughness_mask: Tensor,
|
|
74
74
|
logdist: bool = False,
|
|
75
75
|
max_thickness_share: float = 0.5,
|
|
76
|
+
max_total_thickness: float = None,
|
|
76
77
|
):
|
|
77
78
|
super().__init__(logdist=logdist)
|
|
78
79
|
self.thickness_mask = thickness_mask
|
|
79
80
|
self.roughness_mask = roughness_mask
|
|
80
81
|
self.max_thickness_share = max_thickness_share
|
|
82
|
+
self.max_total_thickness = max_total_thickness
|
|
81
83
|
|
|
82
84
|
def sample(self, batch_size: int,
|
|
83
85
|
total_min_bounds: Tensor,
|
|
@@ -106,7 +108,8 @@ class ConstrainedRoughnessSamplerStrategy(BasicSamplerStrategy):
|
|
|
106
108
|
thickness_mask=self.thickness_mask.to(device),
|
|
107
109
|
roughness_mask=self.roughness_mask.to(device),
|
|
108
110
|
widths_sampler_func=self.widths_sampler_func,
|
|
109
|
-
|
|
111
|
+
coef_roughness=self.max_thickness_share,
|
|
112
|
+
max_total_thickness=self.max_total_thickness,
|
|
110
113
|
)
|
|
111
114
|
|
|
112
115
|
class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
|
|
@@ -129,6 +132,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
|
|
|
129
132
|
logdist: bool = False,
|
|
130
133
|
max_thickness_share: float = 0.5,
|
|
131
134
|
max_sld_share: float = 0.2,
|
|
135
|
+
max_total_thickness: float = None,
|
|
132
136
|
):
|
|
133
137
|
super().__init__(logdist=logdist)
|
|
134
138
|
self.thickness_mask = thickness_mask
|
|
@@ -137,6 +141,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
|
|
|
137
141
|
self.isld_mask = isld_mask
|
|
138
142
|
self.max_thickness_share = max_thickness_share
|
|
139
143
|
self.max_sld_share = max_sld_share
|
|
144
|
+
self.max_total_thickness = max_total_thickness
|
|
140
145
|
|
|
141
146
|
def sample(self, batch_size: int,
|
|
142
147
|
total_min_bounds: Tensor,
|
|
@@ -169,6 +174,7 @@ class ConstrainedRoughnessAndImgSldSamplerStrategy(BasicSamplerStrategy):
|
|
|
169
174
|
widths_sampler_func=self.widths_sampler_func,
|
|
170
175
|
coef_roughness=self.max_thickness_share,
|
|
171
176
|
coef_isld=self.max_sld_share,
|
|
177
|
+
max_total_thickness=self.max_total_thickness,
|
|
172
178
|
)
|
|
173
179
|
|
|
174
180
|
def basic_sampler(
|
|
@@ -214,15 +220,44 @@ def constrained_roughness_sampler(
|
|
|
214
220
|
thickness_mask: Tensor,
|
|
215
221
|
roughness_mask: Tensor,
|
|
216
222
|
widths_sampler_func,
|
|
217
|
-
|
|
223
|
+
coef_roughness: float = 0.5,
|
|
224
|
+
max_total_thickness: float = None,
|
|
218
225
|
):
|
|
219
226
|
params, min_bounds, max_bounds = basic_sampler(
|
|
220
227
|
batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
|
|
221
228
|
widths_sampler_func=widths_sampler_func,
|
|
222
229
|
)
|
|
223
230
|
|
|
231
|
+
if max_total_thickness is not None:
|
|
232
|
+
total_thickness = max_bounds[:, thickness_mask].sum(-1)
|
|
233
|
+
indices = total_thickness > max_total_thickness
|
|
234
|
+
|
|
235
|
+
if indices.any():
|
|
236
|
+
eps = 0.01
|
|
237
|
+
rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
|
|
238
|
+
scale_coef = max_total_thickness / total_thickness * rand_scale
|
|
239
|
+
scale_coef[~indices] = 1.0
|
|
240
|
+
min_bounds[:, thickness_mask] *= scale_coef[:, None]
|
|
241
|
+
max_bounds[:, thickness_mask] *= scale_coef[:, None]
|
|
242
|
+
params[:, thickness_mask] *= scale_coef[:, None]
|
|
243
|
+
|
|
244
|
+
min_bounds[:, thickness_mask] = torch.clamp_min(
|
|
245
|
+
min_bounds[:, thickness_mask],
|
|
246
|
+
total_min_bounds[:, thickness_mask],
|
|
247
|
+
)
|
|
248
|
+
|
|
249
|
+
max_bounds[:, thickness_mask] = torch.clamp_min(
|
|
250
|
+
max_bounds[:, thickness_mask],
|
|
251
|
+
total_min_bounds[:, thickness_mask],
|
|
252
|
+
)
|
|
253
|
+
|
|
254
|
+
params[:, thickness_mask] = torch.clamp_min(
|
|
255
|
+
params[:, thickness_mask],
|
|
256
|
+
total_min_bounds[:, thickness_mask],
|
|
257
|
+
)
|
|
258
|
+
|
|
224
259
|
max_roughness = torch.minimum(
|
|
225
|
-
get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=
|
|
260
|
+
get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
|
|
226
261
|
total_max_bounds[..., roughness_mask]
|
|
227
262
|
)
|
|
228
263
|
min_roughness = total_min_bounds[..., roughness_mask]
|
|
@@ -256,12 +291,41 @@ def constrained_roughness_and_isld_sampler(
|
|
|
256
291
|
widths_sampler_func,
|
|
257
292
|
coef_roughness: float = 0.5,
|
|
258
293
|
coef_isld: float = 0.2,
|
|
294
|
+
max_total_thickness: float = None,
|
|
259
295
|
):
|
|
260
296
|
params, min_bounds, max_bounds = basic_sampler(
|
|
261
297
|
batch_size, total_min_bounds, total_max_bounds, total_min_delta, total_max_delta,
|
|
262
298
|
widths_sampler_func=widths_sampler_func,
|
|
263
299
|
)
|
|
264
300
|
|
|
301
|
+
if max_total_thickness is not None:
|
|
302
|
+
total_thickness = max_bounds[:, thickness_mask].sum(-1)
|
|
303
|
+
indices = total_thickness > max_total_thickness
|
|
304
|
+
|
|
305
|
+
if indices.any():
|
|
306
|
+
eps = 0.01
|
|
307
|
+
rand_scale = torch.rand_like(total_thickness) * eps + 1 - eps
|
|
308
|
+
scale_coef = max_total_thickness / total_thickness * rand_scale
|
|
309
|
+
scale_coef[~indices] = 1.0
|
|
310
|
+
min_bounds[:, thickness_mask] *= scale_coef[:, None]
|
|
311
|
+
max_bounds[:, thickness_mask] *= scale_coef[:, None]
|
|
312
|
+
params[:, thickness_mask] *= scale_coef[:, None]
|
|
313
|
+
|
|
314
|
+
min_bounds[:, thickness_mask] = torch.clamp_min(
|
|
315
|
+
min_bounds[:, thickness_mask],
|
|
316
|
+
total_min_bounds[:, thickness_mask],
|
|
317
|
+
)
|
|
318
|
+
|
|
319
|
+
max_bounds[:, thickness_mask] = torch.clamp_min(
|
|
320
|
+
max_bounds[:, thickness_mask],
|
|
321
|
+
total_min_bounds[:, thickness_mask],
|
|
322
|
+
)
|
|
323
|
+
|
|
324
|
+
params[:, thickness_mask] = torch.clamp_min(
|
|
325
|
+
params[:, thickness_mask],
|
|
326
|
+
total_min_bounds[:, thickness_mask],
|
|
327
|
+
)
|
|
328
|
+
|
|
265
329
|
max_roughness = torch.minimum(
|
|
266
330
|
get_max_allowed_roughness(thicknesses=params[..., thickness_mask], coef=coef_roughness),
|
|
267
331
|
total_max_bounds[..., roughness_mask]
|
|
@@ -16,6 +16,7 @@ __all__ = [
|
|
|
16
16
|
"VariableQ",
|
|
17
17
|
"EquidistantQ",
|
|
18
18
|
"ConstantAngle",
|
|
19
|
+
"MaskedVariableQ",
|
|
19
20
|
]
|
|
20
21
|
|
|
21
22
|
|
|
@@ -50,6 +51,8 @@ class ConstantQ(QGenerator):
|
|
|
50
51
|
q = q[1:]
|
|
51
52
|
else:
|
|
52
53
|
q = q[:-1]
|
|
54
|
+
self.q_min = q.min().item()
|
|
55
|
+
self.q_max = q.max().item()
|
|
53
56
|
self.q = q
|
|
54
57
|
|
|
55
58
|
def get_batch(self, batch_size: int, context: dict = None) -> Tensor:
|
|
@@ -63,14 +66,26 @@ class ConstantQ(QGenerator):
|
|
|
63
66
|
"""
|
|
64
67
|
return self.q.clone()[None].expand(batch_size, self.q.shape[0])
|
|
65
68
|
|
|
69
|
+
def scale_q(self, q):
|
|
70
|
+
"""Scales the q values to the range [-1, 1].
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
q (Tensor): unscaled q values
|
|
74
|
+
|
|
75
|
+
Returns:
|
|
76
|
+
Tensor: scaled q values
|
|
77
|
+
"""
|
|
78
|
+
scaled_q_01 = (q - self.q_min) / (self.q_max - self.q_min)
|
|
79
|
+
return 2.0 * (scaled_q_01 - 0.5)
|
|
80
|
+
|
|
66
81
|
|
|
67
82
|
class VariableQ(QGenerator):
|
|
68
83
|
"""Q generator for reflectivity curves with variable discretization
|
|
69
84
|
|
|
70
85
|
Args:
|
|
71
|
-
q_min_range (list, optional): the range for sampling the minimum q value of the curves,
|
|
72
|
-
q_max_range (list, optional): the range for sampling the maximum q value of the curves,
|
|
73
|
-
n_q_range (list, optional): the range for the number of points in the curves (equidistantly sampled between
|
|
86
|
+
q_min_range (list, optional): the range for sampling the minimum q value of the curves, q_min. Defaults to [0.01, 0.03].
|
|
87
|
+
q_max_range (list, optional): the range for sampling the maximum q value of the curves, q_max. Defaults to [0.1, 0.5].
|
|
88
|
+
n_q_range (list, optional): the range for the number of points in the curves (equidistantly sampled between q_min and q_max,
|
|
74
89
|
the number of points varies between batches but is constant within a batch). Defaults to [64, 256].
|
|
75
90
|
device (optional): the Pytorch device. Defaults to DEFAULT_DEVICE.
|
|
76
91
|
dtype (optional): the Pytorch data type. Defaults to DEFAULT_DTYPE.
|
|
@@ -80,12 +95,14 @@ class VariableQ(QGenerator):
|
|
|
80
95
|
q_min_range: Tuple[float, float] = (0.01, 0.03),
|
|
81
96
|
q_max_range: Tuple[float, float] = (0.1, 0.5),
|
|
82
97
|
n_q_range: Tuple[int, int] = (64, 256),
|
|
98
|
+
mode: str = 'equidistant',
|
|
83
99
|
device=DEFAULT_DEVICE,
|
|
84
100
|
dtype=DEFAULT_DTYPE,
|
|
85
101
|
):
|
|
86
102
|
self.q_min_range = q_min_range
|
|
87
103
|
self.q_max_range = q_max_range
|
|
88
104
|
self.n_q_range = n_q_range
|
|
105
|
+
self.mode = mode
|
|
89
106
|
self.device = device
|
|
90
107
|
self.dtype = dtype
|
|
91
108
|
|
|
@@ -98,14 +115,23 @@ class VariableQ(QGenerator):
|
|
|
98
115
|
Returns:
|
|
99
116
|
Tensor: generated batch of q values
|
|
100
117
|
"""
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
118
|
+
|
|
119
|
+
q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
|
|
120
|
+
q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
|
|
121
|
+
|
|
122
|
+
n_q = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (1,), device=self.device).item()
|
|
123
|
+
|
|
124
|
+
if self.mode == 'equidistant':
|
|
125
|
+
q = torch.linspace(0, 1, n_q, device=self.device, dtype=self.dtype)
|
|
126
|
+
elif self.mode == 'random':
|
|
127
|
+
q = torch.rand(n_q, device=self.device, dtype=self.dtype).sort().values
|
|
128
|
+
elif self.mode == 'logspace':
|
|
129
|
+
q = torch.logspace(
|
|
130
|
+
start=torch.log10(torch.tensor(1e-4, dtype=self.dtype, device=self.device)),
|
|
131
|
+
end=torch.log10(torch.tensor(1.0, dtype=self.dtype, device=self.device)),
|
|
132
|
+
steps=n_q, dtype=self.dtype, device=self.device)
|
|
133
|
+
|
|
134
|
+
q = q_min[:, None] + q * (q_max - q_min)[:, None]
|
|
109
135
|
|
|
110
136
|
return q
|
|
111
137
|
|
|
@@ -178,49 +204,77 @@ class EquidistantQ(QGenerator):
|
|
|
178
204
|
return qs
|
|
179
205
|
|
|
180
206
|
|
|
181
|
-
class
|
|
207
|
+
class MaskedVariableQ:
|
|
182
208
|
def __init__(self,
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
self.
|
|
209
|
+
q_min_range=(0.01, 0.03),
|
|
210
|
+
q_max_range=(0.1, 0.5),
|
|
211
|
+
n_q_range=(64, 256),
|
|
212
|
+
mode='equidistant',
|
|
213
|
+
shuffle_mask=False,
|
|
214
|
+
total_thickness_constraint=True,
|
|
215
|
+
min_points_per_fringe=4,
|
|
216
|
+
device=DEFAULT_DEVICE,
|
|
217
|
+
dtype=DEFAULT_DTYPE):
|
|
218
|
+
self.q_min_range = q_min_range
|
|
219
|
+
self.q_max_range = q_max_range
|
|
220
|
+
self.n_q_range = n_q_range
|
|
193
221
|
self.device = device
|
|
194
222
|
self.dtype = dtype
|
|
195
|
-
|
|
196
|
-
|
|
223
|
+
self.mode = mode
|
|
224
|
+
self.shuffle_mask = shuffle_mask
|
|
225
|
+
self.total_thickness_constraint = total_thickness_constraint
|
|
226
|
+
self.min_points_per_fringe = min_points_per_fringe
|
|
227
|
+
|
|
228
|
+
def get_batch(self, batch_size, context):
|
|
197
229
|
assert context is not None
|
|
198
230
|
|
|
199
|
-
|
|
200
|
-
|
|
231
|
+
q_min = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_min_range[1] - self.q_min_range[0]) + self.q_min_range[0]
|
|
232
|
+
q_max = torch.rand(batch_size, device=self.device, dtype=self.dtype) * (self.q_max_range[1] - self.q_max_range[0]) + self.q_max_range[0]
|
|
201
233
|
|
|
202
|
-
|
|
234
|
+
max_n_q = self.n_q_range[1]
|
|
203
235
|
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
236
|
+
if self.mode == 'equidistant':
|
|
237
|
+
positions = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(batch_size, max_n_q)
|
|
238
|
+
elif self.mode == 'random':
|
|
239
|
+
positions = torch.rand(batch_size, max_n_q, device=self.device, dtype=self.dtype)
|
|
240
|
+
positions, _ = positions.sort(dim=-1)
|
|
241
|
+
elif self.mode == 'mixed':
|
|
242
|
+
positions = torch.empty(batch_size, max_n_q, device=self.device, dtype=self.dtype)
|
|
207
243
|
|
|
208
|
-
|
|
244
|
+
half = batch_size // 2 # half batch gets equidistant
|
|
245
|
+
eq_pos = torch.linspace(0, 1, max_n_q, device=self.device, dtype=self.dtype).expand(half, max_n_q)
|
|
246
|
+
positions[:half] = eq_pos
|
|
209
247
|
|
|
210
|
-
|
|
248
|
+
rand_pos = torch.rand(batch_size - half, max_n_q, device=self.device, dtype=self.dtype) # other half gets sorted random
|
|
249
|
+
rand_pos, _ = rand_pos.sort(dim=-1)
|
|
250
|
+
positions[half:] = rand_pos
|
|
251
|
+
else:
|
|
252
|
+
raise ValueError(f"Unknown spacing mode: {self.mode}")
|
|
211
253
|
|
|
212
|
-
|
|
254
|
+
q = q_min[:, None] + positions * (q_max - q_min)[:, None]
|
|
213
255
|
|
|
214
|
-
|
|
215
|
-
context['num_q_values'] = num_q_values
|
|
256
|
+
n_qs = torch.randint(self.n_q_range[0], self.n_q_range[1] + 1, (batch_size,), device=self.device)
|
|
216
257
|
|
|
217
|
-
|
|
258
|
+
if 'params' in context and self.total_thickness_constraint: ### N_points > 1 + (Q_spread * total_thickness * min_np_per_kiessing_fringe) / (2*pi)
|
|
259
|
+
d_total = context['params'].thicknesses.sum(-1)
|
|
260
|
+
limit = 1 + ((q_max - q_min) * d_total * self.min_points_per_fringe) / (2*np.pi)
|
|
261
|
+
limit = limit.ceil().int()
|
|
262
|
+
n_qs = torch.maximum(n_qs, limit)
|
|
263
|
+
n_qs = torch.clamp(n_qs, max=self.n_q_range[1])
|
|
218
264
|
|
|
265
|
+
indices = torch.arange(max_n_q, device=self.device).expand(batch_size, max_n_q)
|
|
266
|
+
valid_mask = indices < n_qs[:, None] # right side padding
|
|
267
|
+
|
|
268
|
+
if self.shuffle_mask: # shuffle valid positions (inter-spread padding)
|
|
269
|
+
perm = torch.argsort(torch.rand(batch_size, max_n_q, device=self.device), dim=-1)
|
|
270
|
+
valid_mask = torch.gather(valid_mask, dim=1, index=perm)
|
|
271
|
+
|
|
272
|
+
context['key_padding_mask'] = valid_mask
|
|
273
|
+
context['n_points'] = valid_mask.sum(dim=-1)
|
|
274
|
+
|
|
275
|
+
return q
|
|
276
|
+
|
|
277
|
+
def scale_q(self, q):
|
|
278
|
+
scaled_q_01 = (q - self.q_min_range[0]) / (self.q_max_range[1] - self.q_min_range[0])
|
|
219
279
|
|
|
220
|
-
|
|
221
|
-
batch_size = num_q_values.shape[0]
|
|
222
|
-
dqs = (q_max / num_q_values)[:, None]
|
|
223
|
-
q_values = torch.arange(1, num_q_values.max().item() + 1)[None].repeat(batch_size, 1) * dqs
|
|
224
|
-
mask = (q_values > q_max + dqs / 2)
|
|
225
|
-
q_values[mask] = 0.
|
|
226
|
-
return q_values, mask
|
|
280
|
+
return 2.0 * (scaled_q_01 - 0.5)
|
|
@@ -10,6 +10,7 @@ from reflectorch.data_generation.reflectivity.numpy_implementations import (
|
|
|
10
10
|
abeles_np,
|
|
11
11
|
)
|
|
12
12
|
from reflectorch.data_generation.reflectivity.smearing import abeles_constant_smearing
|
|
13
|
+
from reflectorch.data_generation.reflectivity.smearing_pointwise import abeles_pointwise_smearing
|
|
13
14
|
from reflectorch.data_generation.reflectivity.kinematical import kinematical_approximation
|
|
14
15
|
|
|
15
16
|
|
|
@@ -20,9 +21,15 @@ def reflectivity(
|
|
|
20
21
|
sld: Tensor,
|
|
21
22
|
dq: Tensor = None,
|
|
22
23
|
gauss_num: int = 51,
|
|
23
|
-
constant_dq: bool =
|
|
24
|
+
constant_dq: bool = False,
|
|
24
25
|
log: bool = False,
|
|
25
|
-
|
|
26
|
+
q_shift: Tensor = 0.0,
|
|
27
|
+
r_scale: Tensor = 1.0,
|
|
28
|
+
background: Tensor = 0.0,
|
|
29
|
+
solvent_vf = None,
|
|
30
|
+
solvent_mode = 'fronting',
|
|
31
|
+
abeles_func = None,
|
|
32
|
+
**abeles_kwargs
|
|
26
33
|
):
|
|
27
34
|
"""Function which computes the reflectivity curves from thin film parameters.
|
|
28
35
|
By default it uses the fast implementation of the Abeles matrix formalism.
|
|
@@ -37,24 +44,59 @@ def reflectivity(
|
|
|
37
44
|
Either dq if ``constant_dq`` is ``True`` or dq/q if ``constant_dq`` is ``False``. Defaults to None.
|
|
38
45
|
gauss_num (int, optional): the number of gaussians for curve smearing. Defaults to 51.
|
|
39
46
|
constant_dq (bool, optional): if ``True`` the smearing is constant (constant dq at each point in the curve)
|
|
40
|
-
otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to
|
|
47
|
+
otherwise the smearing is linear (constant dq/q at each point in the curve). Defaults to False.
|
|
41
48
|
log (bool, optional): if True the base 10 logarithm of the reflectivity curves is returned. Defaults to False.
|
|
49
|
+
q_shift (float or Tensor, optional): misalignment in q.
|
|
50
|
+
r_scale (float or Tensor, optional): normalization factor (scales reflectivity).
|
|
51
|
+
background (float or Tensor, optional): background intensity.
|
|
42
52
|
abeles_func (Callable, optional): a function implementing the simulation of the reflectivity curves, if different than the default Abeles matrix implementation ('abeles'). Defaults to None.
|
|
43
|
-
|
|
53
|
+
abeles_kwargs: Additional arguments specific to the chosen `abeles_func`.
|
|
44
54
|
Returns:
|
|
45
|
-
Tensor:
|
|
55
|
+
Tensor: the computed reflectivity curves
|
|
46
56
|
"""
|
|
47
57
|
abeles_func = abeles_func or abeles
|
|
48
|
-
q = torch.atleast_2d(q)
|
|
58
|
+
q = torch.atleast_2d(q) + q_shift
|
|
59
|
+
q = torch.clamp(q, min=0.0)
|
|
60
|
+
|
|
61
|
+
if solvent_vf is not None:
|
|
62
|
+
num_layers = thickness.shape[-1]
|
|
63
|
+
if solvent_mode == 'fronting':
|
|
64
|
+
assert sld.shape[-1] == num_layers + 2
|
|
65
|
+
assert solvent_vf.shape[-1] == num_layers
|
|
66
|
+
solvent_sld = sld[..., [0]]
|
|
67
|
+
idx = slice(1, num_layers)
|
|
68
|
+
sld[..., idx] = solvent_vf * solvent_sld + (1.0 - solvent_vf) * sld[..., idx]
|
|
69
|
+
elif solvent_mode == 'backing':
|
|
70
|
+
solvent_sld = sld[..., [-1]]
|
|
71
|
+
idx = slice(1, num_layers) if sld.shape[-1] == num_layers + 2 else slice(0, num_layers)
|
|
72
|
+
sld[..., idx] = solvent_vf * solvent_sld + (1.0 - solvent_vf) * sld[..., idx]
|
|
73
|
+
else:
|
|
74
|
+
raise NotImplementedError
|
|
49
75
|
|
|
50
76
|
if dq is None:
|
|
51
|
-
reflectivity_curves = abeles_func(q, thickness, roughness, sld)
|
|
77
|
+
reflectivity_curves = abeles_func(q, thickness, roughness, sld, **abeles_kwargs)
|
|
52
78
|
else:
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
79
|
+
if dq.shape[-1] > 1:
|
|
80
|
+
reflectivity_curves = abeles_pointwise_smearing(
|
|
81
|
+
q=q, dq=dq, thickness=thickness, roughness=roughness, sld=sld,
|
|
82
|
+
abeles_func=abeles_func, gauss_num=gauss_num,
|
|
83
|
+
**abeles_kwargs,
|
|
84
|
+
)
|
|
85
|
+
else:
|
|
86
|
+
reflectivity_curves = abeles_constant_smearing(
|
|
87
|
+
q, thickness, roughness, sld,
|
|
88
|
+
dq=dq, gauss_num=gauss_num, constant_dq=constant_dq, abeles_func=abeles_func,
|
|
89
|
+
**abeles_kwargs,
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
if isinstance(r_scale, Tensor):
|
|
93
|
+
r_scale = r_scale.view(-1, *[1] * (reflectivity_curves.dim() - 1))
|
|
94
|
+
if isinstance(background, Tensor):
|
|
95
|
+
background = background.view(-1, *[1] * (reflectivity_curves.dim() - 1))
|
|
96
|
+
|
|
97
|
+
reflectivity_curves = reflectivity_curves * r_scale + background
|
|
57
98
|
|
|
58
99
|
if log:
|
|
59
100
|
reflectivity_curves = torch.log10(reflectivity_curves)
|
|
101
|
+
|
|
60
102
|
return reflectivity_curves
|
|
@@ -51,7 +51,7 @@ def kinematical_approximation(
|
|
|
51
51
|
|
|
52
52
|
substrate_sld = sld[:, -1:]
|
|
53
53
|
|
|
54
|
-
rf =
|
|
54
|
+
rf = _get_fresnel_reflectivity(q, substrate_sld[:, None])
|
|
55
55
|
|
|
56
56
|
r = torch.clamp_max_(r * rf / substrate_sld.real ** 2, 1.)
|
|
57
57
|
|
|
@@ -61,12 +61,11 @@ def kinematical_approximation(
|
|
|
61
61
|
return r
|
|
62
62
|
|
|
63
63
|
|
|
64
|
-
def
|
|
65
|
-
_RE_CONST = 0.28174103675406496
|
|
64
|
+
def _get_fresnel_reflectivity(q, substrate_slds):
|
|
65
|
+
_RE_CONST = 0.28174103675406496 # 2/sqrt(16*pi)
|
|
66
66
|
|
|
67
67
|
q_c = torch.sqrt(substrate_slds + 0j) / _RE_CONST * 2
|
|
68
68
|
q_prime = torch.sqrt(q ** 2 - q_c ** 2 + 0j)
|
|
69
69
|
r_f = ((q - q_prime) / (q + q_prime)).abs().float() ** 2
|
|
70
70
|
|
|
71
|
-
return r_f.squeeze(-1)
|
|
72
|
-
|
|
71
|
+
return r_f.squeeze(-1)
|
|
@@ -14,26 +14,41 @@ def abeles_constant_smearing(
|
|
|
14
14
|
roughness: Tensor,
|
|
15
15
|
sld: Tensor,
|
|
16
16
|
dq: Tensor = None,
|
|
17
|
-
gauss_num: int =
|
|
18
|
-
constant_dq: bool =
|
|
17
|
+
gauss_num: int = 31,
|
|
18
|
+
constant_dq: bool = False,
|
|
19
19
|
abeles_func=None,
|
|
20
|
+
**abeles_kwargs
|
|
20
21
|
):
|
|
21
22
|
abeles_func = abeles_func or abeles
|
|
23
|
+
|
|
24
|
+
if dq.dtype != thickness.dtype:
|
|
25
|
+
q = q.to(thickness)
|
|
26
|
+
|
|
27
|
+
if dq.dtype != thickness.dtype:
|
|
28
|
+
dq = dq.to(thickness)
|
|
29
|
+
|
|
30
|
+
if q.shape[0] == 1:
|
|
31
|
+
q = q.repeat(thickness.shape[0], 1)
|
|
32
|
+
|
|
22
33
|
q_lin = _get_q_axes(q, dq, gauss_num, constant_dq=constant_dq)
|
|
23
34
|
kernels = _get_t_gauss_kernels(dq, gauss_num)
|
|
24
|
-
|
|
25
|
-
curves = abeles_func(q_lin, thickness, roughness, sld)
|
|
35
|
+
|
|
36
|
+
curves = abeles_func(q_lin, thickness, roughness, sld, **abeles_kwargs)
|
|
26
37
|
|
|
27
38
|
padding = (kernels.shape[-1] - 1) // 2
|
|
39
|
+
padded_curves = pad(curves, (padding, padding), 'reflect')
|
|
40
|
+
|
|
28
41
|
smeared_curves = conv1d(
|
|
29
|
-
|
|
30
|
-
)
|
|
42
|
+
padded_curves, kernels[:, None], groups=kernels.shape[0],
|
|
43
|
+
)
|
|
31
44
|
|
|
32
45
|
if q.shape[0] != smeared_curves.shape[0]:
|
|
33
|
-
|
|
34
|
-
|
|
46
|
+
repeat_factor = smeared_curves.shape[0] // q.shape[0]
|
|
47
|
+
q = q.repeat(repeat_factor, 1)
|
|
48
|
+
q_lin = q_lin.repeat(repeat_factor, 1)
|
|
49
|
+
|
|
35
50
|
smeared_curves = _batch_linear_interp1d(q_lin, smeared_curves, q)
|
|
36
|
-
|
|
51
|
+
|
|
37
52
|
return smeared_curves
|
|
38
53
|
|
|
39
54
|
|
|
@@ -55,7 +70,7 @@ def _get_t_gauss_kernels(resolutions: Tensor, gaussnum: int = 51):
|
|
|
55
70
|
return gauss_y
|
|
56
71
|
|
|
57
72
|
|
|
58
|
-
def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool =
|
|
73
|
+
def _get_q_axes(q: Tensor, resolutions: Tensor, gaussnum: int = 51, constant_dq: bool = False):
|
|
59
74
|
if constant_dq:
|
|
60
75
|
return _get_q_axes_for_constant_dq(q, resolutions, gaussnum)
|
|
61
76
|
else:
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import scipy
|
|
3
|
+
import numpy as np
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from typing import Tuple
|
|
6
|
+
|
|
7
|
+
from reflectorch.data_generation.reflectivity.abeles import abeles
|
|
8
|
+
|
|
9
|
+
#Pytorch version based on the JAX implementation of pointwise smearing in the refnx package.
|
|
10
|
+
|
|
11
|
+
@lru_cache(maxsize=128)
|
|
12
|
+
def gauss_legendre(n: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
13
|
+
"""
|
|
14
|
+
Calculate Gaussian quadrature abscissae and weights.
|
|
15
|
+
|
|
16
|
+
Args:
|
|
17
|
+
n (int): Gaussian quadrature order.
|
|
18
|
+
|
|
19
|
+
Returns:
|
|
20
|
+
Tuple[torch.Tensor, torch.Tensor]: The abscissae and weights for Gauss-Legendre integration.
|
|
21
|
+
"""
|
|
22
|
+
return scipy.special.p_roots(n)
|
|
23
|
+
|
|
24
|
+
def gauss(x: torch.Tensor) -> torch.Tensor:
|
|
25
|
+
"""
|
|
26
|
+
Calculate the Gaussian function.
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
x (torch.Tensor): Input tensor.
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
torch.Tensor: Output tensor after applying the Gaussian function.
|
|
33
|
+
"""
|
|
34
|
+
return torch.exp(-0.5 * x * x)
|
|
35
|
+
|
|
36
|
+
def abeles_pointwise_smearing(
|
|
37
|
+
q: torch.Tensor,
|
|
38
|
+
dq: torch.Tensor,
|
|
39
|
+
thickness: torch.Tensor,
|
|
40
|
+
roughness: torch.Tensor,
|
|
41
|
+
sld: torch.Tensor,
|
|
42
|
+
gauss_num: int = 17,
|
|
43
|
+
abeles_func=None,
|
|
44
|
+
**abeles_kwargs,
|
|
45
|
+
) -> torch.Tensor:
|
|
46
|
+
"""
|
|
47
|
+
Compute reflectivity with variable smearing using Gaussian quadrature.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
q (torch.Tensor): The momentum transfer (q) values.
|
|
51
|
+
dq (torch.Tensor): The resolution for curve smearing.
|
|
52
|
+
thickness (torch.Tensor): The layer thicknesses.
|
|
53
|
+
roughness (torch.Tensor): The interlayer roughnesses.
|
|
54
|
+
sld (torch.Tensor): The SLDs of the layers.
|
|
55
|
+
sld_magnetic (torch.Tensor, optional): The magnetic SLDs of the layers.
|
|
56
|
+
magnetization_angle (torch.Tensor, optional): The magnetization angles.
|
|
57
|
+
polarizer_eff (torch.Tensor, optional): The polarizer efficiency.
|
|
58
|
+
analyzer_eff (torch.Tensor, optional): The analyzer efficiency.
|
|
59
|
+
abeles_func (Callable, optional): A function implementing the simulation of the reflectivity curves.
|
|
60
|
+
gauss_num (int, optional): Gaussian quadrature order. Defaults to 17.
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
torch.Tensor: The computed reflectivity curves.
|
|
64
|
+
"""
|
|
65
|
+
abeles_func = abeles_func or abeles
|
|
66
|
+
|
|
67
|
+
if q.shape[0] == 1:
|
|
68
|
+
q = q.repeat(thickness.shape[0], 1)
|
|
69
|
+
|
|
70
|
+
_FWHM = 2 * np.sqrt(2 * np.log(2.0))
|
|
71
|
+
_INTLIMIT = 3.5
|
|
72
|
+
|
|
73
|
+
bs = q.shape[0]
|
|
74
|
+
nq = q.shape[-1]
|
|
75
|
+
device = q.device
|
|
76
|
+
|
|
77
|
+
quad_order = gauss_num
|
|
78
|
+
abscissa, weights = gauss_legendre(quad_order)
|
|
79
|
+
abscissa = torch.tensor(abscissa)[None, :, None].to(device)
|
|
80
|
+
weights = torch.tensor(weights)[None, :, None].to(device)
|
|
81
|
+
prefactor = 1.0 / np.sqrt(2 * np.pi)
|
|
82
|
+
|
|
83
|
+
gaussvals = prefactor * gauss(abscissa * _INTLIMIT)
|
|
84
|
+
|
|
85
|
+
va = q[:, None, :] - _INTLIMIT * dq[:, None, :] / _FWHM
|
|
86
|
+
vb = q[:, None, :] + _INTLIMIT * dq[:, None, :] / _FWHM
|
|
87
|
+
|
|
88
|
+
qvals_for_res_0 = (abscissa * (vb - va) + vb + va) / 2
|
|
89
|
+
qvals_for_res = qvals_for_res_0.reshape(bs, -1)
|
|
90
|
+
|
|
91
|
+
refl_curves = abeles_func(
|
|
92
|
+
q=qvals_for_res,
|
|
93
|
+
thickness=thickness,
|
|
94
|
+
roughness=roughness,
|
|
95
|
+
sld=sld,
|
|
96
|
+
**abeles_kwargs
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
# Handle multiple channels
|
|
100
|
+
if refl_curves.dim() == 3:
|
|
101
|
+
n_channels = refl_curves.shape[1]
|
|
102
|
+
refl_curves = refl_curves.reshape(bs, n_channels, quad_order, nq)
|
|
103
|
+
refl_curves = refl_curves * gaussvals.unsqueeze(1) * weights.unsqueeze(1)
|
|
104
|
+
refl_curves = torch.sum(refl_curves, dim=2) * _INTLIMIT
|
|
105
|
+
else:
|
|
106
|
+
refl_curves = refl_curves.reshape(bs, quad_order, nq)
|
|
107
|
+
refl_curves = refl_curves * gaussvals * weights
|
|
108
|
+
refl_curves = torch.sum(refl_curves, dim=1) * _INTLIMIT
|
|
109
|
+
|
|
110
|
+
return refl_curves
|